From 903fd1b6feb1e8721ba85f849cfb9168e1ae3236 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Mon, 4 Nov 2019 16:33:37 -0500 Subject: [PATCH 01/26] save --- tfjs-core/benchmarks/index.html | 9 +++++---- tfjs-core/benchmarks/modelConfig.js | 13 +++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index dc6f78dd63d..266470d9579 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -100,6 +100,7 @@

TensorFlow.js Model Benchmark

run: (v) => { runBenchmark(); }, + backend: 'webgl', }; const modalDiv = document.getElementById('modal-msg'); @@ -340,14 +341,14 @@

TensorFlow.js Model Benchmark

await tf.ready(); gui.add(state, 'numRuns'); - gui.add(state, 'benchmark', ['', ...Object.keys(benchmarks)]); + gui.add(state, 'benchmark', Object.keys(benchmarks)); + gui.add(state, 'backend', ['webgl', 'cpu']).onChange(backend => { + tf.setBackend(backend); + }); gui.add(state, 'run'); showVersions(); await showEnvironment(); - - // Run first benchmark by default on page load. - runBenchmark(); } onPageLoad(); diff --git a/tfjs-core/benchmarks/modelConfig.js b/tfjs-core/benchmarks/modelConfig.js index ba9562a51d9..6d7349f95b5 100644 --- a/tfjs-core/benchmarks/modelConfig.js +++ b/tfjs-core/benchmarks/modelConfig.js @@ -72,6 +72,19 @@ const sentences = [ ]; const benchmarks = { + 'facemesh': { + load: async () => { + const url = + 'https://storage.googleapis.com/learnjs-data/face_mesh/model.json'; + return tf.loadGraphModel(url); + }, + predictFunc: () => { + const zeros = tf.zeros([1, 128, 128, 3]); + return model => { + return model.predict(zeros); + }; + }, + }, 'mobilenet': { load: async () => { const url = From 9bccb88d8ddb89f84d3baf6dde49eecc8e66d4c8 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 20 Nov 2019 09:51:57 -0500 Subject: [PATCH 02/26] save --- tfjs-core/benchmarks/index.html | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index 266470d9579..a797bf0a9ff 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -87,6 +87,7 @@

TensorFlow.js Model Benchmark

+ @@ -342,7 +343,7 @@

TensorFlow.js Model Benchmark

gui.add(state, 'numRuns'); gui.add(state, 'benchmark', Object.keys(benchmarks)); - gui.add(state, 'backend', ['webgl', 'cpu']).onChange(backend => { + gui.add(state, 'backend', ['webgl', 'cpu', 'wasm']).onChange(backend => { tf.setBackend(backend); }); gui.add(state, 'run'); From 1354b90ac1d6dc8b0a317d17619e5141f95e3c8f Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 20 Nov 2019 09:53:18 -0500 Subject: [PATCH 03/26] save --- tfjs-core/benchmarks/index.html | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tfjs-core/benchmarks/index.html b/tfjs-core/benchmarks/index.html index a797bf0a9ff..7a93c075826 100644 --- a/tfjs-core/benchmarks/index.html +++ b/tfjs-core/benchmarks/index.html @@ -84,7 +84,7 @@

TensorFlow.js Model Benchmark

- + From 7aeaec35965a322d2538b5c860adc1ee5beccf69 Mon Sep 17 00:00:00 2001 From: Daniel Smilkov Date: Wed, 20 Nov 2019 10:04:45 -0500 Subject: [PATCH 04/26] save --- tfjs-core/benchmarks/tf-backend-wasm.js | 1632 + tfjs-core/benchmarks/tf-core.js | 34293 ++++++++++++++++++ tfjs-core/benchmarks/tfjs-backend-wasm.wasm | Bin 0 -> 86029 bytes 3 files changed, 35925 insertions(+) create mode 100644 tfjs-core/benchmarks/tf-backend-wasm.js create mode 100644 tfjs-core/benchmarks/tf-core.js create mode 100644 tfjs-core/benchmarks/tfjs-backend-wasm.wasm diff --git a/tfjs-core/benchmarks/tf-backend-wasm.js b/tfjs-core/benchmarks/tf-backend-wasm.js new file mode 100644 index 00000000000..ac9b239ef85 --- /dev/null +++ b/tfjs-core/benchmarks/tf-backend-wasm.js @@ -0,0 +1,1632 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +(function (global, factory) { + typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports, require('@tensorflow/tfjs-core')) : + typeof define === 'function' && define.amd ? define(['exports', '@tensorflow/tfjs-core'], factory) : + (global = global || self, factory(global.tf = global.tf || {}, global.tf)); +}(this, (function (exports, tfjsCore) { 'use strict'; + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function registerUnaryKernel(kernelName) { + var wasmFunc; + function setupFunc(backend) { + wasmFunc = + backend.wasm.cwrap(kernelName, null /* void */, ['number', 'number']); + } + function kernelFunc(args) { + var backend = args.backend, x = args.inputs.x; + var xId = backend.dataIdMap.get(x.dataId).id; + var out = backend.makeOutput(x.shape, x.dtype); + var outId = backend.dataIdMap.get(out.dataId).id; + // Short-circuit zero-sized tensors. + if (tfjsCore.util.sizeFromShape(out.shape) === 0) { + return out; + } + wasmFunc(xId, outId); + return out; + } + tfjsCore.registerKernel({ kernelName: kernelName, backendName: 'wasm', setupFunc: setupFunc, kernelFunc: kernelFunc }); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Abs'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // This enum must align with the enum defined in cc/backend.h. + var CppDType; + (function (CppDType) { + CppDType[CppDType["float32"] = 0] = "float32"; + CppDType[CppDType["int32"] = 1] = "int32"; + CppDType[CppDType["bool"] = 2] = "bool"; + CppDType[CppDType["string"] = 3] = "string"; + CppDType[CppDType["complex64"] = 4] = "complex64"; + })(CppDType || (CppDType = {})); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function registerBinaryKernel(kernelName) { + var wasmFunc; + function setupFunc(backend) { + wasmFunc = backend.wasm.cwrap(kernelName, null /* void */, [ + 'number', + 'number', + 'number', + 'number' // out_id + ]); + } + function kernelFunc(args) { + var backend = args.backend, inputs = args.inputs; + var a = inputs.a, b = inputs.b; + var aId = backend.dataIdMap.get(a.dataId).id; + var bId = backend.dataIdMap.get(b.dataId).id; + var newShape = tfjsCore.backend_util.assertAndGetBroadcastShape(a.shape, b.shape); + var out = backend.makeOutput(newShape, a.dtype); + // Short-circuit zero-sized tensors. + if (tfjsCore.util.sizeFromShape(newShape) === 0) { + return out; + } + var aBroadcastDims = tfjsCore.backend_util.getBroadcastDims(a.shape, newShape); + var bBroadcastDims = tfjsCore.backend_util.getBroadcastDims(b.shape, newShape); + var loopsOverAllOfA = aBroadcastDims.every(function (v, i) { return v === i; }); + var loopsOverAllOfB = bBroadcastDims.every(function (v, i) { return v === i; }); + var outId = backend.dataIdMap.get(out.dataId).id; + if (loopsOverAllOfA && loopsOverAllOfB) { + wasmFunc(aId, bId, CppDType[a.dtype], outId); + return out; + } + else { + throw new Error('Broadcasting along inner dims is not yet supported'); + } + } + tfjsCore.registerKernel({ kernelName: kernelName, backendName: 'wasm', setupFunc: setupFunc, kernelFunc: kernelFunc }); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerBinaryKernel('Add'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmAvgPool; + function setup(backend) { + wasmAvgPool = backend.wasm.cwrap('AvgPool', null /* void */, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + ]); + } + function avgPool(args) { + var inputs = args.inputs, attrs = args.attrs, backend = args.backend; + var convInfo = attrs; + var x = inputs.x; + var xId = backend.dataIdMap.get(x.dataId).id; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var padTop = convInfo.padInfo.top; + var padRight = convInfo.padInfo.right; + var padBottom = convInfo.padInfo.bottom; + var padLeft = convInfo.padInfo.left; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var inputChannels = convInfo.inChannels; + var outputChannels = convInfo.outChannels; + if (convInfo.dataFormat !== 'channelsLast') { + throw new Error("wasm backend does not support dataFormat:'" + + (convInfo.dataFormat + "'. Please use 'channelsLast'.")); + } + if (convInfo.dilationWidth !== 1 || convInfo.dilationHeight !== 1) { + throw new Error("was backend only supports average pooling with dilation = [1, 1], " + + ("got [" + convInfo.dilationHeight + ", " + convInfo.dilationWidth + "].")); + } + var out = backend.makeOutput(convInfo.outShape, 'float32'); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmAvgPool(xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, strideHeight, strideWidth, inputChannels, outputChannels, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'AvgPool', + backendName: 'wasm', + setupFunc: setup, + kernelFunc: avgPool + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmFunc; + function setupFunc(backend) { + wasmFunc = backend.wasm.cwrap('AddN', null /* void */, [ + 'array', + 'number', + 'number', + 'number', + ]); + } + function addn(args) { + var inputs = args.inputs, backend = args.backend; + var out = backend.makeOutput(inputs[0].shape, inputs[0].dtype); + // Short-circuit zero-sized tensors. + if (tfjsCore.util.sizeFromShape(out.shape) === 0) { + return out; + } + var inputIds = inputs.map(function (x) { return backend.dataIdMap.get(x.dataId).id; }); + var inputIdsBytes = new Uint8Array(new Int32Array(inputIds).buffer); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmFunc(inputIdsBytes, inputIds.length, CppDType[out.dtype], outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'AddN', + backendName: 'wasm', + setupFunc: setupFunc, + kernelFunc: addn, + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmBatchMatMul; + function setup$1(backend) { + wasmBatchMatMul = backend.wasm.cwrap('BatchMatMul', null /* void */, [ + 'number', 'number', 'number', 'number', 'number', 'number', 'number', + 'number', 'number', 'number', 'number', 'number', 'number' + ]); + } + function batchMatMul(args) { + var inputs = args.inputs, backend = args.backend, attrs = args.attrs; + var a = inputs.a, b = inputs.b; + if (a.dtype !== 'float32' || b.dtype !== 'float32') { + throw new Error("BatchMatMul for non non-float32 tensors not yet supported."); + } + var transposeA = attrs.transposeA, transposeB = attrs.transposeB; + var aId = backend.dataIdMap.get(a.dataId).id; + var bId = backend.dataIdMap.get(b.dataId).id; + var sharedDim = transposeA ? a.shape[1] : a.shape[2]; + var leftDim = transposeA ? a.shape[2] : a.shape[1]; + var rightDim = transposeB ? b.shape[1] : b.shape[2]; + var batchDim = a.shape[0]; + var aStrides = tfjsCore.util.computeStrides(a.shape); + var bStrides = tfjsCore.util.computeStrides(b.shape); + var _a = transposeA ? + [aStrides[0], 1, aStrides[1]] : + [aStrides[0], aStrides[1], 1], aBatch = _a[0], aOuterStep = _a[1], aInnerStep = _a[2]; + var _b = transposeB ? + [1, bStrides[1], bStrides[0]] : + [bStrides[1], 1, bStrides[0]], bInnerStep = _b[0], bOuterStep = _b[1], bBatch = _b[2]; + var out = backend.makeOutput([batchDim, leftDim, rightDim], a.dtype); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmBatchMatMul(aId, bId, sharedDim, leftDim, rightDim, batchDim, aBatch, aOuterStep, aInnerStep, bBatch, bOuterStep, bInnerStep, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'BatchMatMul', + backendName: 'wasm', + setupFunc: setup$1, + kernelFunc: batchMatMul + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function cast(args) { + var x = args.inputs.x, dtype = args.attrs.dtype, backend = args.backend; + var out = backend.makeOutput(x.shape, dtype); + var inVals = backend.typedArrayFromHeap(x); + var outVals = backend.typedArrayFromHeap(out); + outVals.set(inVals); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Cast', + backendName: 'wasm', + kernelFunc: cast, + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmClip; + function setup$2(backend) { + wasmClip = backend.wasm.cwrap('ClipByValue', null /* void */, [ + 'number', + 'number', + 'number', + 'number' // out_id + ]); + } + function clip(args) { + var inputs = args.inputs, backend = args.backend, attrs = args.attrs; + var x = inputs.x; + var min = attrs.min, max = attrs.max; + var xId = backend.dataIdMap.get(x.dataId).id; + var out = backend.makeOutput(x.shape, 'float32'); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmClip(xId, min, max, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'ClipByValue', + backendName: 'wasm', + setupFunc: setup$2, + kernelFunc: clip + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function concat(args) { + var inputs = args.inputs, backend = args.backend, axis = args.attrs.axis; + var outShape = tfjsCore.backend_util.computeOutShape(inputs.map(function (t) { return t.shape; }), axis); + var out = backend.makeOutput(outShape, inputs[0].dtype); + var batchDim = tfjsCore.util.sizeFromShape(inputs[0].shape.slice(0, axis)); + var sumInnerDims = 0; + var innerDims = inputs.map(function (input) { + var innerDim = tfjsCore.util.sizeFromShape(input.shape.slice(axis)); + sumInnerDims += innerDim; + return innerDim; + }); + var inVals = inputs.map(function (input) { return backend.typedArrayFromHeap(input); }); + var outVals = backend.typedArrayFromHeap(out); + for (var b = 0; b < batchDim; b++) { + var outOffset = b * sumInnerDims; + for (var i = 0; i < inVals.length; i++) { + var innerDim = innerDims[i]; + var inOffset = b * innerDim; + var vals = inVals[i].subarray(inOffset, inOffset + innerDim); + outVals.set(vals, outOffset); + outOffset += innerDim; + } + } + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Concat', + backendName: 'wasm', + kernelFunc: concat, + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmConv2d; + function setup$3(backend) { + wasmConv2d = backend.wasm.cwrap('Conv2D', null /* void */, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + ]); + } + function conv2d(args) { + var inputs = args.inputs, attrs = args.attrs, backend = args.backend; + var convInfo = attrs; + var x = inputs.x, filter = inputs.filter; + var xId = backend.dataIdMap.get(x.dataId).id; + var filterId = backend.dataIdMap.get(filter.dataId).id; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var padTop = convInfo.padInfo.top; + var padRight = convInfo.padInfo.right; + var padBottom = convInfo.padInfo.bottom; + var padLeft = convInfo.padInfo.left; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var inputChannels = convInfo.inChannels; + var outputChannels = convInfo.outChannels; + var isSamePad = convInfo.padInfo.type === 'SAME' ? 1 : 0; + if (convInfo.dataFormat !== 'channelsLast') { + throw new Error("wasm backend Conv2D does not support dataFormat:'" + + (convInfo.dataFormat + "'. Please use 'channelsLast'.")); + } + var out = backend.makeOutput(convInfo.outShape, 'float32'); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmConv2d(xId, x.shape[0], x.shape[1], x.shape[2], filterId, filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, isSamePad, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Conv2D', + backendName: 'wasm', + setupFunc: setup$3, + kernelFunc: conv2d + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Must match enum in CropAndResize.cc + var InterpolationMethod; + (function (InterpolationMethod) { + InterpolationMethod[InterpolationMethod["bilinear"] = 0] = "bilinear"; + InterpolationMethod[InterpolationMethod["nearest"] = 1] = "nearest"; + })(InterpolationMethod || (InterpolationMethod = {})); + var wasmCropAndResize; + function setup$4(backend) { + wasmCropAndResize = backend.wasm.cwrap('CropAndResize', null /*void*/, [ + 'number', + 'number', + 'number', + 'number', + 'array', + 'number', + 'number', + 'number', + 'number', + 'number' // out id + ]); + } + function cropAndResize(args) { + var backend = args.backend, inputs = args.inputs, attrs = args.attrs; + var method = attrs.method, extrapolationValue = attrs.extrapolationValue, cropSize = attrs.cropSize; + var images = inputs.images, boxes = inputs.boxes, boxInd = inputs.boxInd; + var numBoxes = boxes.shape[0]; + var _a = cropSize, cropHeight = _a[0], cropWidth = _a[1]; + var outShape = [numBoxes, cropHeight, cropWidth, images.shape[3]]; + var imagesId = backend.dataIdMap.get(images.dataId).id; + var boxesId = backend.dataIdMap.get(boxes.dataId).id; + var boxIndId = backend.dataIdMap.get(boxInd.dataId).id; + var out = backend.makeOutput(outShape, images.dtype); + var outId = backend.dataIdMap.get(out.dataId).id; + var imagesShapeBytes = new Uint8Array(new Int32Array(images.shape).buffer); + wasmCropAndResize(imagesId, boxesId, boxIndId, numBoxes, imagesShapeBytes, cropHeight, cropWidth, InterpolationMethod[method], extrapolationValue, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'CropAndResize', + backendName: 'wasm', + setupFunc: setup$4, + kernelFunc: cropAndResize + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmDepthwiseConv2d; + function setup$5(backend) { + wasmDepthwiseConv2d = + backend.wasm.cwrap('DepthwiseConv2dNative', null /* void */, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + ]); + } + function depthwiseConv2d(args) { + var inputs = args.inputs, attrs = args.attrs, backend = args.backend; + var convInfo = attrs; + var x = inputs.x, filter = inputs.filter; + var xId = backend.dataIdMap.get(x.dataId).id; + var filterId = backend.dataIdMap.get(filter.dataId).id; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var padTop = convInfo.padInfo.top; + var padRight = convInfo.padInfo.right; + var padBottom = convInfo.padInfo.bottom; + var padLeft = convInfo.padInfo.left; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var inputChannels = convInfo.inChannels; + var outputChannels = convInfo.outChannels; + var isSamePad = convInfo.padInfo.type === 'SAME' ? 1 : 0; + if (convInfo.dataFormat !== 'channelsLast') { + throw new Error("wasm backend DepthwiseConv2dNative does not support dataFormat:'" + + (convInfo.dataFormat + "'. Please use 'channelsLast'.")); + } + var out = backend.makeOutput(convInfo.outShape, 'float32'); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmDepthwiseConv2d(xId, x.shape[0], x.shape[1], x.shape[2], filterId, filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, isSamePad, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'DepthwiseConv2dNative', + backendName: 'wasm', + setupFunc: setup$5, + kernelFunc: depthwiseConv2d + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerBinaryKernel('Div'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmBatchNorm; + function setup$6(backend) { + wasmBatchNorm = backend.wasm.cwrap('FusedBatchNorm', null /* void */, ['number', 'number', 'number', 'number', 'number', 'number', 'number']); + } + function fusedBatchNorm(args) { + var backend = args.backend, inputs = args.inputs, attrs = args.attrs; + var varianceEpsilon = attrs.varianceEpsilon; + var x = inputs.x, mean = inputs.mean, variance = inputs.variance, offset = inputs.offset, scale = inputs.scale; + var xId = backend.dataIdMap.get(x.dataId).id; + var meanId = backend.dataIdMap.get(mean.dataId).id; + var varianceId = backend.dataIdMap.get(variance.dataId).id; + var offsetId = offset != null ? backend.dataIdMap.get(offset.dataId).id : -1; + var scaleId = scale != null ? backend.dataIdMap.get(scale.dataId).id : -1; + var out = backend.makeOutput(x.shape, x.dtype); + // Short-circuit zero-sized tensors. + if (tfjsCore.util.sizeFromShape(x.shape) === 0) { + return out; + } + var outId = backend.dataIdMap.get(out.dataId).id; + wasmBatchNorm(xId, meanId, varianceId, offsetId, scaleId, varianceEpsilon, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'BatchNormalization', + backendName: 'wasm', + setupFunc: setup$6, + kernelFunc: fusedBatchNorm + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmFusedConv2d; + function setup$7(backend) { + wasmFusedConv2d = backend.wasm.cwrap('FusedConv2D', null /* void */, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + ]); + } + function fusedConv2d(args) { + var inputs = args.inputs, attrs = args.attrs, backend = args.backend; + var convInfo = attrs.convInfo, activation = attrs.activation; + if (activation !== 'linear') { + throw new Error(activation + " activation not yet supported for FusedConv2D " + + "in the wasm backend."); + } + var x = inputs.x, filter = inputs.filter, bias = inputs.bias; + var xId = backend.dataIdMap.get(x.dataId).id; + var filterId = backend.dataIdMap.get(filter.dataId).id; + var outputChannels = convInfo.outChannels; + var biasId = -1; + if (bias != null) { + var biasData = backend.dataIdMap.get(bias.dataId); + if (biasData.shape.length !== 1) { + throw new Error("FusedConv2D only supports rank-1 bias but got " + + ("rank " + biasData.shape.length + ".")); + } + if (biasData.shape[0] !== outputChannels) { + throw new Error("FusedConv2D bias shape (" + biasData.shape + ") does not " + + ("match the number of output channels (" + outputChannels + ")")); + } + biasId = biasData.id; + } + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var padTop = convInfo.padInfo.top; + var padRight = convInfo.padInfo.right; + var padBottom = convInfo.padInfo.bottom; + var padLeft = convInfo.padInfo.left; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var inputChannels = convInfo.inChannels; + var isSamePad = convInfo.padInfo.type === 'SAME' ? 1 : 0; + var batchSize = convInfo.batchSize; + var inHeight = convInfo.inHeight; + var inWidth = convInfo.inWidth; + if (convInfo.dataFormat !== 'channelsLast') { + throw new Error("wasm backend FusedConv2D does not support dataFormat:'" + + (convInfo.dataFormat + "'. Please use 'channelsLast'.")); + } + var out = backend.makeOutput(convInfo.outShape, 'float32'); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmFusedConv2d(xId, batchSize, inHeight, inWidth, filterId, filterHeight, filterWidth, biasId, padTop, padRight, padBottom, padLeft, isSamePad, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'FusedConv2D', + backendName: 'wasm', + setupFunc: setup$7, + kernelFunc: fusedConv2d + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmMax; + function setup$8(backend) { + wasmMax = + backend.wasm.cwrap('Max', null /*void*/, ['number, number, number']); + } + function max(args) { + var backend = args.backend, inputs = args.inputs, attrs = args.attrs; + var axes = attrs.axes; + var x = inputs.x; + var xId = backend.dataIdMap.get(x.dataId).id; + tfjsCore.backend_util.assertAxesAreInnerMostDims('max', axes, x.shape.length); + var _a = tfjsCore.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var reduceSize = tfjsCore.util.sizeFromShape(reduceShape); + var out = backend.makeOutput(outShape, x.dtype); + if (tfjsCore.util.sizeFromShape(x.shape) === 0) { + return out; + } + var outId = backend.dataIdMap.get(out.dataId).id; + wasmMax(xId, reduceSize, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Max', + backendName: 'wasm', + setupFunc: setup$8, + kernelFunc: max + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmMaxPool; + function setup$9(backend) { + wasmMaxPool = backend.wasm.cwrap('MaxPool', null /* void */, [ + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + 'number', + ]); + } + function maxPool(args) { + var inputs = args.inputs, attrs = args.attrs, backend = args.backend; + var convInfo = attrs; + var x = inputs.x; + var xId = backend.dataIdMap.get(x.dataId).id; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var padTop = convInfo.padInfo.top; + var padRight = convInfo.padInfo.right; + var padBottom = convInfo.padInfo.bottom; + var padLeft = convInfo.padInfo.left; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var inputChannels = convInfo.inChannels; + var outputChannels = convInfo.outChannels; + if (convInfo.dataFormat !== 'channelsLast') { + throw new Error("wasm backend does not support dataFormat:'" + + (convInfo.dataFormat + "'. Please use 'channelsLast'.")); + } + var out = backend.makeOutput(convInfo.outShape, 'float32'); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmMaxPool(xId, x.shape[0], x.shape[1], x.shape[2], filterHeight, filterWidth, padTop, padRight, padBottom, padLeft, dilationHeight, dilationWidth, strideHeight, strideWidth, inputChannels, outputChannels, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'MaxPool', + backendName: 'wasm', + setupFunc: setup$9, + kernelFunc: maxPool + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmMin; + function setup$a(backend) { + wasmMin = + backend.wasm.cwrap('Min', null /*void*/, ['number, number, number']); + } + function min(args) { + var backend = args.backend, inputs = args.inputs, attrs = args.attrs; + var axes = attrs.axes; + var x = inputs.x; + var xId = backend.dataIdMap.get(x.dataId).id; + tfjsCore.backend_util.assertAxesAreInnerMostDims('min', axes, x.shape.length); + var _a = tfjsCore.backend_util.computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var reduceSize = tfjsCore.util.sizeFromShape(reduceShape); + var out = backend.makeOutput(outShape, x.dtype); + if (tfjsCore.util.sizeFromShape(x.shape) === 0) { + return out; + } + var outId = backend.dataIdMap.get(out.dataId).id; + wasmMin(xId, reduceSize, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Min', + backendName: 'wasm', + setupFunc: setup$a, + kernelFunc: min + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerBinaryKernel('Mul'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmPadV2; + function setup$b(backend) { + wasmPadV2 = backend.wasm.cwrap('PadV2', null /* void */, [ + 'number', + 'array', + 'number', + 'number', + 'array', + 'number', + 'number', + ]); + } + function pad(args) { + var x = args.inputs.x, backend = args.backend, _a = args.attrs, paddings = _a.paddings, constantValue = _a.constantValue; + var outShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + x.shape[i] + p[1]; } /* afterPad */); + var xId = backend.dataIdMap.get(x.dataId).id; + var out = backend.makeOutput(outShape, x.dtype); + var outId = backend.dataIdMap.get(out.dataId).id; + var xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + var paddingsFlat = [].concat.apply([], paddings); + var paddingsBytes = new Uint8Array(new Int32Array(paddingsFlat).buffer); + wasmPadV2(xId, xShapeBytes, x.shape.length, CppDType[x.dtype], paddingsBytes, constantValue, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'PadV2', + backendName: 'wasm', + kernelFunc: pad, + setupFunc: setup$b + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmPrelu; + function setup$c(backend) { + wasmPrelu = backend.wasm.cwrap('Prelu', null /* void */, [ + 'number', + 'number', + 'number' // out_id + ]); + } + function prelu(args) { + var inputs = args.inputs, backend = args.backend; + var x = inputs.x, alpha = inputs.alpha; + var xId = backend.dataIdMap.get(x.dataId).id; + var weightsId = backend.dataIdMap.get(alpha.dataId).id; + var out = backend.makeOutput(x.shape, 'float32'); + var outId = backend.dataIdMap.get(out.dataId).id; + wasmPrelu(xId, weightsId, outId); + return out; + } + tfjsCore.registerKernel({ + kernelName: 'Prelu', + backendName: 'wasm', + setupFunc: setup$c, + kernelFunc: prelu + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function reshape(args) { + var x = args.inputs.x, shape = args.attrs.shape; + return { dataId: x.dataId, shape: shape, dtype: x.dtype }; + } + tfjsCore.registerKernel({ + kernelName: 'Reshape', + backendName: 'wasm', + kernelFunc: reshape, + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Sigmoid'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function slice(args) { + var x = args.inputs.x, _a = args.attrs, begin = _a.begin, size = _a.size, backend = args.backend; + var isContinous = tfjsCore.slice_util.isSliceContinous(x.shape, begin, size); + var xVals = backend.typedArrayFromHeap(x); + var out = backend.makeOutput(size, x.dtype); + var outVals = backend.typedArrayFromHeap(out); + var xStrides = tfjsCore.util.computeStrides(x.shape); + if (isContinous) { + var flatOffset = tfjsCore.slice_util.computeFlatOffset(begin, xStrides); + outVals.set(xVals.subarray(flatOffset, flatOffset + tfjsCore.util.sizeFromShape(size))); + return out; + } + var rank = x.shape.length; + if (rank === 2) { + slice2d(xVals, xStrides[0], outVals, begin, size); + } + else if (rank === 3) { + slice3d(xVals, xStrides[0], xStrides[1], outVals, begin, size); + } + else if (rank === 4) { + slice4d(xVals, xStrides[0], xStrides[1], xStrides[2], outVals, begin, size); + } + else { + genericSliceSlow(xVals, x, outVals, begin, size); + } + return out; + } + function slice2d(xVals, xStride, outVals, begin, size) { + var outOffset = 0; + var beginI = begin[0]; + var beginJ = begin[1]; + var endI = beginI + size[0]; + for (var i = beginI; i < endI; i++) { + var xOffset = i * xStride + beginJ; + outVals.set(xVals.subarray(xOffset, xOffset + size[1]), outOffset); + outOffset += size[1]; + } + } + function slice3d(xVals, xStride1, xStride2, outVals, begin, size) { + var outOffset = 0; + var beginI = begin[0]; + var beginJ = begin[1]; + var beginK = begin[2]; + var endI = beginI + size[0]; + var endJ = beginJ + size[1]; + for (var i = beginI; i < endI; i++) { + for (var j = beginJ; j < endJ; j++) { + var xOffset = i * xStride1 + j * xStride2 + beginK; + outVals.set(xVals.subarray(xOffset, xOffset + size[2]), outOffset); + outOffset += size[2]; + } + } + } + function slice4d(xVals, xStride1, xStride2, xStride3, outVals, begin, size) { + var outOffset = 0; + var beginI = begin[0]; + var beginJ = begin[1]; + var beginK = begin[2]; + var endI = beginI + size[0]; + var endJ = beginJ + size[1]; + var endK = beginK + size[2]; + var beginL = begin[3]; + for (var i = beginI; i < endI; i++) { + for (var j = beginJ; j < endJ; j++) { + for (var k = beginK; k < endK; k++) { + var xOffset = i * xStride1 + j * xStride2 + k * xStride3 + beginL; + outVals.set(xVals.subarray(xOffset, xOffset + size[3]), outOffset); + outOffset += size[3]; + } + } + } + } + function genericSliceSlow(xVals, xInfo, outVals, begin, size) { + var outBuf = tfjsCore.buffer(size, xInfo.dtype, outVals); + var xBuf = tfjsCore.buffer(xInfo.shape, xInfo.dtype, xVals); + for (var i = 0; i < outBuf.size; ++i) { + var loc = outBuf.indexToLoc(i); + var xLoc = loc.map(function (idx, j) { return idx + begin[j]; }); + outVals[i] = xBuf.get.apply(xBuf, xLoc); + } + } + tfjsCore.registerKernel({ + kernelName: 'Slice', + backendName: 'wasm', + kernelFunc: slice, + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerUnaryKernel('Square'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerBinaryKernel('Sub'); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var wasmTranspose; + function setup$d(backend) { + wasmTranspose = backend.wasm.cwrap('Transpose', null /* void */, [ + 'number', + 'array', + 'number', + 'number', + 'number', + 'array', + 'number', + ]); + } + function transpose(args) { + var inputs = args.inputs, backend = args.backend, attrs = args.attrs; + // Reduce any dimensions with size one. Lower-rank transpose kernel performs + // better due to simpler memory access pattern. + var _a = removeOneSizeDims(inputs.x.shape, attrs.perm), reducedShape = _a[0], perm = _a[1]; + var x = { + dataId: inputs.x.dataId, + shape: reducedShape, + dtype: inputs.x.dtype + }; + var permIsNoOp = true; + for (var i = 0; i < perm.length; i++) { + if (perm[i] !== i) { + permIsNoOp = false; + } + } + var outShape = computeOutShape(inputs.x.shape, attrs.perm); + if (permIsNoOp) { + return { dataId: x.dataId, shape: outShape, dtype: x.dtype }; + } + var out = backend.makeOutput(outShape, x.dtype); + var xId = backend.dataIdMap.get(x.dataId).id; + var outId = backend.dataIdMap.get(out.dataId).id; + var permBytes = new Uint8Array(new Int32Array(perm).buffer); + var xShapeBytes = new Uint8Array(new Int32Array(x.shape).buffer); + wasmTranspose(xId, xShapeBytes, x.shape.length, CppDType[x.dtype], outId, permBytes, perm.length); + return out; + } + function computeOutShape(inShape, perm) { + var outShape = new Array(inShape.length); + for (var i = 0; i < outShape.length; i++) { + outShape[i] = inShape[perm[i]]; + } + return outShape; + } + function removeOneSizeDims(shape, perm) { + var newShape = []; + var newPerm = []; + for (var i = 0; i < shape.length; ++i) { + if (shape[i] !== 1) { + newShape.push(shape[i]); + } + if (shape[perm[i]] !== 1) { + newPerm.push(perm[i]); + } + } + for (var i = 0; i < newPerm.length; ++i) { + var minValIdx = -1; + for (var j = 0; j < newPerm.length; ++j) { + if (newPerm[j] >= i && + (minValIdx === -1 || newPerm[minValIdx] > newPerm[j])) { + minValIdx = j; + } + } + newPerm[minValIdx] = i; + } + return [newShape, newPerm]; + } + tfjsCore.registerKernel({ + kernelName: 'Transpose', + backendName: 'wasm', + kernelFunc: transpose, + setupFunc: setup$d, + }); + + /*! ***************************************************************************** + Copyright (c) Microsoft Corporation. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of the + License at http://www.apache.org/licenses/LICENSE-2.0 + + THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED + WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, + MERCHANTABLITY OR NON-INFRINGEMENT. + + See the Apache Version 2.0 License for specific language governing permissions + and limitations under the License. + ***************************************************************************** */ + /* global Reflect, Promise */ + + var extendStatics = function(d, b) { + extendStatics = Object.setPrototypeOf || + ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || + function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; }; + return extendStatics(d, b); + }; + + function __extends(d, b) { + extendStatics(d, b); + function __() { this.constructor = d; } + d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); + } + + function __awaiter(thisArg, _arguments, P, generator) { + return new (P || (P = Promise))(function (resolve, reject) { + function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } + function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } + function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } + step((generator = generator.apply(thisArg, _arguments || [])).next()); + }); + } + + function __generator(thisArg, body) { + var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; + return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; + function verb(n) { return function (v) { return step([n, v]); }; } + function step(op) { + if (f) throw new TypeError("Generator is already executing."); + while (_) try { + if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; + if (y = 0, t) op = [op[0] & 2, t.value]; + switch (op[0]) { + case 0: case 1: t = op; break; + case 4: _.label++; return { value: op[1], done: false }; + case 5: _.label++; y = op[1]; op = [0]; continue; + case 7: op = _.ops.pop(); _.trys.pop(); continue; + default: + if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } + if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } + if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } + if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } + if (t[2]) _.ops.pop(); + _.trys.pop(); continue; + } + op = body.call(thisArg, _); + } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } + if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; + } + } + + function createCommonjsModule(fn, module) { + return module = { exports: {} }, fn(module, module.exports), module.exports; + } + + var tfjsBackendWasm = createCommonjsModule(function (module, exports) { + var WasmBackendModule = (function() { + var _scriptDir = typeof document !== 'undefined' && document.currentScript ? document.currentScript.src : undefined; + return ( + function(WasmBackendModule) { + WasmBackendModule = WasmBackendModule || {}; + + var Module=typeof WasmBackendModule!=="undefined"?WasmBackendModule:{};var moduleOverrides={};var key;for(key in Module){if(Module.hasOwnProperty(key)){moduleOverrides[key]=Module[key];}}var arguments_=[];var thisProgram="./this.program";var quit_=function(status,toThrow){throw toThrow};var ENVIRONMENT_IS_WEB=true;var scriptDirectory="";function locateFile(path){if(Module["locateFile"]){return Module["locateFile"](path,scriptDirectory)}return scriptDirectory+path}var readBinary;{if(document.currentScript){scriptDirectory=document.currentScript.src;}if(_scriptDir){scriptDirectory=_scriptDir;}if(scriptDirectory.indexOf("blob:")!==0){scriptDirectory=scriptDirectory.substr(0,scriptDirectory.lastIndexOf("/")+1);}else{scriptDirectory="";}}var out=Module["print"]||console.log.bind(console);var err=Module["printErr"]||console.warn.bind(console);for(key in moduleOverrides){if(moduleOverrides.hasOwnProperty(key)){Module[key]=moduleOverrides[key];}}moduleOverrides=null;if(Module["arguments"])arguments_=Module["arguments"];if(Module["thisProgram"])thisProgram=Module["thisProgram"];if(Module["quit"])quit_=Module["quit"];var wasmBinary;if(Module["wasmBinary"])wasmBinary=Module["wasmBinary"];var noExitRuntime;if(Module["noExitRuntime"])noExitRuntime=Module["noExitRuntime"];if(typeof WebAssembly!=="object"){err("no native wasm support detected");}var wasmMemory;var wasmTable=new WebAssembly.Table({"initial":71,"maximum":71+0,"element":"anyfunc"});var ABORT=false;function assert(condition,text){if(!condition){abort("Assertion failed: "+text);}}function getCFunc(ident){var func=Module["_"+ident];assert(func,"Cannot call unknown function "+ident+", make sure it is exported");return func}function ccall(ident,returnType,argTypes,args,opts){var toC={"string":function(str){var ret=0;if(str!==null&&str!==undefined&&str!==0){var len=(str.length<<2)+1;ret=stackAlloc(len);stringToUTF8(str,ret,len);}return ret},"array":function(arr){var ret=stackAlloc(arr.length);writeArrayToMemory(arr,ret);return ret}};function convertReturnValue(ret){if(returnType==="string")return UTF8ToString(ret);if(returnType==="boolean")return Boolean(ret);return ret}var func=getCFunc(ident);var cArgs=[];var stack=0;if(args){for(var i=0;i=endIdx))++endPtr;if(endPtr-idx>16&&u8Array.subarray&&UTF8Decoder){return UTF8Decoder.decode(u8Array.subarray(idx,endPtr))}else{var str="";while(idx>10,56320|ch&1023);}}}return str}function UTF8ToString(ptr,maxBytesToRead){return ptr?UTF8ArrayToString(HEAPU8,ptr,maxBytesToRead):""}function stringToUTF8Array(str,outU8Array,outIdx,maxBytesToWrite){if(!(maxBytesToWrite>0))return 0;var startIdx=outIdx;var endIdx=outIdx+maxBytesToWrite-1;for(var i=0;i=55296&&u<=57343){var u1=str.charCodeAt(++i);u=65536+((u&1023)<<10)|u1&1023;}if(u<=127){if(outIdx>=endIdx)break;outU8Array[outIdx++]=u;}else if(u<=2047){if(outIdx+1>=endIdx)break;outU8Array[outIdx++]=192|u>>6;outU8Array[outIdx++]=128|u&63;}else if(u<=65535){if(outIdx+2>=endIdx)break;outU8Array[outIdx++]=224|u>>12;outU8Array[outIdx++]=128|u>>6&63;outU8Array[outIdx++]=128|u&63;}else{if(outIdx+3>=endIdx)break;outU8Array[outIdx++]=240|u>>18;outU8Array[outIdx++]=128|u>>12&63;outU8Array[outIdx++]=128|u>>6&63;outU8Array[outIdx++]=128|u&63;}}outU8Array[outIdx]=0;return outIdx-startIdx}function stringToUTF8(str,outPtr,maxBytesToWrite){return stringToUTF8Array(str,HEAPU8,outPtr,maxBytesToWrite)}var UTF16Decoder=typeof TextDecoder!=="undefined"?new TextDecoder("utf-16le"):undefined;function writeArrayToMemory(array,buffer){HEAP8.set(array,buffer);}var WASM_PAGE_SIZE=65536;function alignUp(x,multiple){if(x%multiple>0){x+=multiple-x%multiple;}return x}var buffer,HEAP8,HEAPU8,HEAP16,HEAPU16,HEAP32,HEAPU32,HEAPF32,HEAPF64;function updateGlobalBufferAndViews(buf){buffer=buf;Module["HEAP8"]=HEAP8=new Int8Array(buf);Module["HEAP16"]=HEAP16=new Int16Array(buf);Module["HEAP32"]=HEAP32=new Int32Array(buf);Module["HEAPU8"]=HEAPU8=new Uint8Array(buf);Module["HEAPU16"]=HEAPU16=new Uint16Array(buf);Module["HEAPU32"]=HEAPU32=new Uint32Array(buf);Module["HEAPF32"]=HEAPF32=new Float32Array(buf);Module["HEAPF64"]=HEAPF64=new Float64Array(buf);}var DYNAMIC_BASE=5247600,DYNAMICTOP_PTR=4560;var INITIAL_TOTAL_MEMORY=Module["TOTAL_MEMORY"]||16777216;if(Module["wasmMemory"]){wasmMemory=Module["wasmMemory"];}else{wasmMemory=new WebAssembly.Memory({"initial":INITIAL_TOTAL_MEMORY/WASM_PAGE_SIZE});}if(wasmMemory){buffer=wasmMemory.buffer;}INITIAL_TOTAL_MEMORY=buffer.byteLength;updateGlobalBufferAndViews(buffer);HEAP32[DYNAMICTOP_PTR>>2]=DYNAMIC_BASE;function callRuntimeCallbacks(callbacks){while(callbacks.length>0){var callback=callbacks.shift();if(typeof callback=="function"){callback();continue}var func=callback.func;if(typeof func==="number"){if(callback.arg===undefined){Module["dynCall_v"](func);}else{Module["dynCall_vi"](func,callback.arg);}}else{func(callback.arg===undefined?null:callback.arg);}}}var __ATPRERUN__=[];var __ATINIT__=[];var __ATMAIN__=[];var __ATPOSTRUN__=[];function preRun(){if(Module["preRun"]){if(typeof Module["preRun"]=="function")Module["preRun"]=[Module["preRun"]];while(Module["preRun"].length){addOnPreRun(Module["preRun"].shift());}}callRuntimeCallbacks(__ATPRERUN__);}function initRuntime(){callRuntimeCallbacks(__ATINIT__);}function preMain(){callRuntimeCallbacks(__ATMAIN__);}function postRun(){if(Module["postRun"]){if(typeof Module["postRun"]=="function")Module["postRun"]=[Module["postRun"]];while(Module["postRun"].length){addOnPostRun(Module["postRun"].shift());}}callRuntimeCallbacks(__ATPOSTRUN__);}function addOnPreRun(cb){__ATPRERUN__.unshift(cb);}function addOnPostRun(cb){__ATPOSTRUN__.unshift(cb);}var Math_ceil=Math.ceil;var Math_floor=Math.floor;var runDependencies=0;var runDependencyWatcher=null;var dependenciesFulfilled=null;function addRunDependency(id){runDependencies++;if(Module["monitorRunDependencies"]){Module["monitorRunDependencies"](runDependencies);}}function removeRunDependency(id){runDependencies--;if(Module["monitorRunDependencies"]){Module["monitorRunDependencies"](runDependencies);}if(runDependencies==0){if(runDependencyWatcher!==null){clearInterval(runDependencyWatcher);runDependencyWatcher=null;}if(dependenciesFulfilled){var callback=dependenciesFulfilled;dependenciesFulfilled=null;callback();}}}Module["preloadedImages"]={};Module["preloadedAudios"]={};function abort(what){if(Module["onAbort"]){Module["onAbort"](what);}what+="";out(what);err(what);ABORT=true;what="abort("+what+"). Build with -s ASSERTIONS=1 for more info.";throw new WebAssembly.RuntimeError(what)}var dataURIPrefix="data:application/octet-stream;base64,";function isDataURI(filename){return String.prototype.startsWith?filename.startsWith(dataURIPrefix):filename.indexOf(dataURIPrefix)===0}var wasmBinaryFile="tfjs-backend-wasm.wasm";if(!isDataURI(wasmBinaryFile)){wasmBinaryFile=locateFile(wasmBinaryFile);}function getBinary(){try{if(wasmBinary){return new Uint8Array(wasmBinary)}if(readBinary){return readBinary(wasmBinaryFile)}else{throw "both async and sync fetching of the wasm failed"}}catch(err){abort(err);}}function getBinaryPromise(){if(!wasmBinary&&(ENVIRONMENT_IS_WEB)&&typeof fetch==="function"){return fetch(wasmBinaryFile,{credentials:"same-origin"}).then(function(response){if(!response["ok"]){throw "failed to load wasm binary file at '"+wasmBinaryFile+"'"}return response["arrayBuffer"]()}).catch(function(){return getBinary()})}return new Promise(function(resolve,reject){resolve(getBinary());})}function createWasm(){var info={"env":asmLibraryArg,"wasi_unstable":asmLibraryArg};function receiveInstance(instance,module){var exports=instance.exports;Module["asm"]=exports;removeRunDependency();}addRunDependency();function receiveInstantiatedSource(output){receiveInstance(output["instance"]);}function instantiateArrayBuffer(receiver){return getBinaryPromise().then(function(binary){return WebAssembly.instantiate(binary,info)}).then(receiver,function(reason){err("failed to asynchronously prepare wasm: "+reason);abort(reason);})}function instantiateAsync(){if(!wasmBinary&&typeof WebAssembly.instantiateStreaming==="function"&&!isDataURI(wasmBinaryFile)&&typeof fetch==="function"){fetch(wasmBinaryFile,{credentials:"same-origin"}).then(function(response){var result=WebAssembly.instantiateStreaming(response,info);return result.then(receiveInstantiatedSource,function(reason){err("wasm streaming compile failed: "+reason);err("falling back to ArrayBuffer instantiation");instantiateArrayBuffer(receiveInstantiatedSource);})});}else{return instantiateArrayBuffer(receiveInstantiatedSource)}}if(Module["instantiateWasm"]){try{var exports=Module["instantiateWasm"](info,receiveInstance);return exports}catch(e){err("Module.instantiateWasm callback failed with error: "+e);return false}}instantiateAsync();return {}}__ATINIT__.push({func:function(){___wasm_call_ctors();}});function _abort(){abort();}function _emscripten_memcpy_big(dest,src,num){HEAPU8.set(HEAPU8.subarray(src,src+num),dest);}function _emscripten_get_heap_size(){return HEAP8.length}function emscripten_realloc_buffer(size){try{wasmMemory.grow(size-buffer.byteLength+65535>>16);updateGlobalBufferAndViews(wasmMemory.buffer);return 1}catch(e){}}function _emscripten_resize_heap(requestedSize){var oldSize=_emscripten_get_heap_size();var PAGE_MULTIPLE=65536;var LIMIT=2147483648-PAGE_MULTIPLE;if(requestedSize>LIMIT){return false}var MIN_TOTAL_MEMORY=16777216;var newSize=Math.max(oldSize,MIN_TOTAL_MEMORY);while(newSize>2];return ret},getStr:function(){var ret=UTF8ToString(SYSCALLS.get());return ret},get64:function(){var low=SYSCALLS.get(),high=SYSCALLS.get();return low},getZero:function(){SYSCALLS.get();}};function _fd_close(fd){try{return 0}catch(e){if(typeof FS==="undefined"||!(e instanceof FS.ErrnoError))abort(e);return e.errno}}function _fd_seek(fd,offset_low,offset_high,whence,newOffset){try{return 0}catch(e){if(typeof FS==="undefined"||!(e instanceof FS.ErrnoError))abort(e);return e.errno}}function _fd_write(fd,iov,iovcnt,pnum){try{var num=0;for(var i=0;i>2];var len=HEAP32[iov+(i*8+4)>>2];for(var j=0;j>2]=num;return 0}catch(e){if(typeof FS==="undefined"||!(e instanceof FS.ErrnoError))abort(e);return e.errno}}function _roundf(d){d=+d;return d>=+0?+Math_floor(d+ +.5):+Math_ceil(d-+.5)}var asmLibraryArg={"a":_abort,"d":_emscripten_memcpy_big,"e":_emscripten_resize_heap,"f":_fd_close,"c":_fd_seek,"g":_fd_write,"memory":wasmMemory,"b":_roundf,"table":wasmTable};var asm=createWasm();Module["asm"]=asm;var ___wasm_call_ctors=Module["___wasm_call_ctors"]=function(){return Module["asm"]["h"].apply(null,arguments)};var _init=Module["_init"]=function(){return Module["asm"]["i"].apply(null,arguments)};var _register_tensor=Module["_register_tensor"]=function(){return Module["asm"]["j"].apply(null,arguments)};var _dispose_data=Module["_dispose_data"]=function(){return Module["asm"]["k"].apply(null,arguments)};var _dispose=Module["_dispose"]=function(){return Module["asm"]["l"].apply(null,arguments)};var _Abs=Module["_Abs"]=function(){return Module["asm"]["m"].apply(null,arguments)};var _Add=Module["_Add"]=function(){return Module["asm"]["n"].apply(null,arguments)};var _AddN=Module["_AddN"]=function(){return Module["asm"]["o"].apply(null,arguments)};var _AvgPool=Module["_AvgPool"]=function(){return Module["asm"]["p"].apply(null,arguments)};var _BatchMatMul=Module["_BatchMatMul"]=function(){return Module["asm"]["q"].apply(null,arguments)};var _ClipByValue=Module["_ClipByValue"]=function(){return Module["asm"]["r"].apply(null,arguments)};var _Conv2D=Module["_Conv2D"]=function(){return Module["asm"]["s"].apply(null,arguments)};var _CropAndResize=Module["_CropAndResize"]=function(){return Module["asm"]["t"].apply(null,arguments)};var _DepthwiseConv2dNative=Module["_DepthwiseConv2dNative"]=function(){return Module["asm"]["u"].apply(null,arguments)};var _Div=Module["_Div"]=function(){return Module["asm"]["v"].apply(null,arguments)};var _FusedBatchNorm=Module["_FusedBatchNorm"]=function(){return Module["asm"]["w"].apply(null,arguments)};var _FusedConv2D=Module["_FusedConv2D"]=function(){return Module["asm"]["x"].apply(null,arguments)};var _Max=Module["_Max"]=function(){return Module["asm"]["y"].apply(null,arguments)};var _MaxPool=Module["_MaxPool"]=function(){return Module["asm"]["z"].apply(null,arguments)};var _Min=Module["_Min"]=function(){return Module["asm"]["A"].apply(null,arguments)};var _Mul=Module["_Mul"]=function(){return Module["asm"]["B"].apply(null,arguments)};var _PadV2=Module["_PadV2"]=function(){return Module["asm"]["C"].apply(null,arguments)};var _Prelu=Module["_Prelu"]=function(){return Module["asm"]["D"].apply(null,arguments)};var _Sigmoid=Module["_Sigmoid"]=function(){return Module["asm"]["E"].apply(null,arguments)};var _Square=Module["_Square"]=function(){return Module["asm"]["F"].apply(null,arguments)};var _Sub=Module["_Sub"]=function(){return Module["asm"]["G"].apply(null,arguments)};var _Transpose=Module["_Transpose"]=function(){return Module["asm"]["H"].apply(null,arguments)};var _malloc=Module["_malloc"]=function(){return Module["asm"]["I"].apply(null,arguments)};var _free=Module["_free"]=function(){return Module["asm"]["J"].apply(null,arguments)};var stackSave=Module["stackSave"]=function(){return Module["asm"]["K"].apply(null,arguments)};var stackAlloc=Module["stackAlloc"]=function(){return Module["asm"]["L"].apply(null,arguments)};var stackRestore=Module["stackRestore"]=function(){return Module["asm"]["M"].apply(null,arguments)};var dynCall_vi=Module["dynCall_vi"]=function(){return Module["asm"]["N"].apply(null,arguments)};var dynCall_v=Module["dynCall_v"]=function(){return Module["asm"]["O"].apply(null,arguments)};Module["asm"]=asm;Module["cwrap"]=cwrap;var calledRun;Module["then"]=function(func){if(calledRun){func(Module);}else{var old=Module["onRuntimeInitialized"];Module["onRuntimeInitialized"]=function(){if(old)old();func(Module);};}return Module};dependenciesFulfilled=function runCaller(){if(!calledRun)run();if(!calledRun)dependenciesFulfilled=runCaller;};function run(args){if(runDependencies>0){return}preRun();if(runDependencies>0)return;function doRun(){if(calledRun)return;calledRun=true;if(ABORT)return;initRuntime();preMain();if(Module["onRuntimeInitialized"])Module["onRuntimeInitialized"]();postRun();}if(Module["setStatus"]){Module["setStatus"]("Running...");setTimeout(function(){setTimeout(function(){Module["setStatus"]("");},1);doRun();},1);}else{doRun();}}Module["run"]=run;if(Module["preInit"]){if(typeof Module["preInit"]=="function")Module["preInit"]=[Module["preInit"]];while(Module["preInit"].length>0){Module["preInit"].pop()();}}noExitRuntime=true;run(); + + + return WasmBackendModule + } + ); + })(); + module.exports = WasmBackendModule; + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var _this = undefined; + var WASM_PRIORITY = 2; + var BackendWasm = /** @class */ (function (_super) { + __extends(BackendWasm, _super); + function BackendWasm(wasm) { + var _this = _super.call(this) || this; + _this.wasm = wasm; + _this.dataIdNextNumber = 0; + _this.wasm.tfjs.init(); + _this.dataIdMap = new tfjsCore.DataStorage(_this, tfjsCore.engine()); + return _this; + } + BackendWasm.prototype.write = function (values, shape, dtype) { + var dataId = {}; + this.move(dataId, values, shape, dtype); + return dataId; + }; + BackendWasm.prototype.numDataIds = function () { + return this.dataIdMap.numDataIds(); + }; + BackendWasm.prototype.move = function (dataId, values, shape, dtype) { + var id = this.dataIdNextNumber++; + if (dtype === 'string') { + var stringBytes = values; + this.dataIdMap.set(dataId, { id: id, stringBytes: stringBytes, shape: shape, dtype: dtype, memoryOffset: null }); + return; + } + var size = tfjsCore.util.sizeFromShape(shape); + var numBytes = size * tfjsCore.util.bytesPerElement(dtype); + var memoryOffset = this.wasm._malloc(numBytes); + this.dataIdMap.set(dataId, { id: id, memoryOffset: memoryOffset, shape: shape, dtype: dtype }); + this.wasm.tfjs.registerTensor(id, size, memoryOffset); + if (values != null) { + this.wasm.HEAPU8.set(new Uint8Array(values.buffer), memoryOffset); + } + }; + BackendWasm.prototype.read = function (dataId) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.readSync(dataId)]; + }); + }); + }; + BackendWasm.prototype.readSync = function (dataId) { + var _a = this.dataIdMap.get(dataId), memoryOffset = _a.memoryOffset, dtype = _a.dtype, shape = _a.shape, stringBytes = _a.stringBytes; + if (dtype === 'string') { + return stringBytes; + } + var bytes = this.wasm.HEAPU8.slice(memoryOffset, memoryOffset + tfjsCore.util.sizeFromShape(shape) * tfjsCore.util.bytesPerElement(dtype)); + return typedArrayFromBuffer(bytes.buffer, dtype); + }; + BackendWasm.prototype.disposeData = function (dataId) { + var data = this.dataIdMap.get(dataId); + this.wasm._free(data.memoryOffset); + this.wasm.tfjs.disposeData(data.id); + this.dataIdMap.delete(dataId); + }; + BackendWasm.prototype.floatPrecision = function () { + return 32; + }; + // Returns the memory offset of a tensor. Useful for debugging and unit + // testing. + BackendWasm.prototype.getMemoryOffset = function (dataId) { + return this.dataIdMap.get(dataId).memoryOffset; + }; + BackendWasm.prototype.dispose = function () { + this.wasm.tfjs.dispose(); + this.wasm = null; + }; + BackendWasm.prototype.memory = function () { + return { unreliable: false }; + }; + BackendWasm.prototype.makeOutput = function (shape, dtype) { + var dataId = this.write(null /* values */, shape, dtype); + return { dataId: dataId, shape: shape, dtype: dtype }; + }; + BackendWasm.prototype.typedArrayFromHeap = function (_a) { + var shape = _a.shape, dtype = _a.dtype, dataId = _a.dataId; + var buffer = this.wasm.HEAPU8.buffer; + var memoryOffset = this.dataIdMap.get(dataId).memoryOffset; + var size = tfjsCore.util.sizeFromShape(shape); + switch (dtype) { + case 'float32': + return new Float32Array(buffer, memoryOffset, size); + case 'int32': + return new Int32Array(buffer, memoryOffset, size); + case 'bool': + return new Uint8Array(buffer, memoryOffset, size); + default: + throw new Error("Uknown dtype " + dtype); + } + }; + return BackendWasm; + }(tfjsCore.KernelBackend)); + tfjsCore.registerBackend('wasm', function () { return __awaiter(_this, void 0, void 0, function () { + var wasm; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, init()]; + case 1: + wasm = (_a.sent()).wasm; + return [2 /*return*/, new BackendWasm(wasm)]; + } + }); + }); }, WASM_PRIORITY); + /** + * Initializes the wasm module and creates the js <--> wasm bridge. + * + * NOTE: We wrap the wasm module in a object with property 'wasm' instead of + * returning Promise to avoid freezing Chrome (last tested in + * Chrome 76). + */ + function init() { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, new Promise(function (resolve) { + var wasm = tfjsBackendWasm(); + var voidReturnType = null; + // Using the tfjs namespace to avoid conflict with emscripten's API. + wasm.tfjs = { + init: wasm.cwrap('init', null, []), + registerTensor: wasm.cwrap('register_tensor', null, [ + 'number', + 'number', + 'number', + ]), + disposeData: wasm.cwrap('dispose_data', voidReturnType, ['number']), + dispose: wasm.cwrap('dispose', voidReturnType, []), + }; + wasm.onRuntimeInitialized = function () { return resolve({ wasm: wasm }); }; + })]; + }); + }); + } + function typedArrayFromBuffer(buffer, dtype) { + switch (dtype) { + case 'float32': + return new Float32Array(buffer); + case 'int32': + return new Int32Array(buffer); + case 'bool': + return new Uint8Array(buffer); + default: + throw new Error("Uknown dtype " + dtype); + } + } + + exports.BackendWasm = BackendWasm; + + Object.defineProperty(exports, '__esModule', { value: true }); + +}))); +//# sourceMappingURL=tf-backend-wasm.js.map diff --git a/tfjs-core/benchmarks/tf-core.js b/tfjs-core/benchmarks/tf-core.js new file mode 100644 index 00000000000..9750da18f2a --- /dev/null +++ b/tfjs-core/benchmarks/tf-core.js @@ -0,0 +1,34293 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +(function (global, factory) { + typeof exports === 'object' && typeof module !== 'undefined' ? factory(exports) : + typeof define === 'function' && define.amd ? define(['exports'], factory) : + (global = global || self, factory(global.tf = global.tf || {})); +}(this, function (exports) { 'use strict'; + + /*! ***************************************************************************** + Copyright (c) Microsoft Corporation. All rights reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use + this file except in compliance with the License. You may obtain a copy of the + License at http://www.apache.org/licenses/LICENSE-2.0 + + THIS CODE IS PROVIDED ON AN *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED + WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE, + MERCHANTABLITY OR NON-INFRINGEMENT. + + See the Apache Version 2.0 License for specific language governing permissions + and limitations under the License. + ***************************************************************************** */ + /* global Reflect, Promise */ + + var extendStatics = function(d, b) { + extendStatics = Object.setPrototypeOf || + ({ __proto__: [] } instanceof Array && function (d, b) { d.__proto__ = b; }) || + function (d, b) { for (var p in b) if (b.hasOwnProperty(p)) d[p] = b[p]; }; + return extendStatics(d, b); + }; + + function __extends(d, b) { + extendStatics(d, b); + function __() { this.constructor = d; } + d.prototype = b === null ? Object.create(b) : (__.prototype = b.prototype, new __()); + } + + function __awaiter(thisArg, _arguments, P, generator) { + return new (P || (P = Promise))(function (resolve, reject) { + function fulfilled(value) { try { step(generator.next(value)); } catch (e) { reject(e); } } + function rejected(value) { try { step(generator["throw"](value)); } catch (e) { reject(e); } } + function step(result) { result.done ? resolve(result.value) : new P(function (resolve) { resolve(result.value); }).then(fulfilled, rejected); } + step((generator = generator.apply(thisArg, _arguments || [])).next()); + }); + } + + function __generator(thisArg, body) { + var _ = { label: 0, sent: function() { if (t[0] & 1) throw t[1]; return t[1]; }, trys: [], ops: [] }, f, y, t, g; + return g = { next: verb(0), "throw": verb(1), "return": verb(2) }, typeof Symbol === "function" && (g[Symbol.iterator] = function() { return this; }), g; + function verb(n) { return function (v) { return step([n, v]); }; } + function step(op) { + if (f) throw new TypeError("Generator is already executing."); + while (_) try { + if (f = 1, y && (t = op[0] & 2 ? y["return"] : op[0] ? y["throw"] || ((t = y["return"]) && t.call(y), 0) : y.next) && !(t = t.call(y, op[1])).done) return t; + if (y = 0, t) op = [op[0] & 2, t.value]; + switch (op[0]) { + case 0: case 1: t = op; break; + case 4: _.label++; return { value: op[1], done: false }; + case 5: _.label++; y = op[1]; op = [0]; continue; + case 7: op = _.ops.pop(); _.trys.pop(); continue; + default: + if (!(t = _.trys, t = t.length > 0 && t[t.length - 1]) && (op[0] === 6 || op[0] === 2)) { _ = 0; continue; } + if (op[0] === 3 && (!t || (op[1] > t[0] && op[1] < t[3]))) { _.label = op[1]; break; } + if (op[0] === 6 && _.label < t[1]) { _.label = t[1]; t = op; break; } + if (t && _.label < t[2]) { _.label = t[2]; _.ops.push(op); break; } + if (t[2]) _.ops.pop(); + _.trys.pop(); continue; + } + op = body.call(thisArg, _); + } catch (e) { op = [6, e]; y = 0; } finally { f = t = 0; } + if (op[0] & 5) throw op[1]; return { value: op[0] ? op[1] : void 0, done: true }; + } + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Expects flags from URL in the format ?tfjsflags=FLAG1:1,FLAG2:true. + var TENSORFLOWJS_FLAGS_PREFIX = 'tfjsflags'; + /** + * The environment contains evaluated flags as well as the registered platform. + * This is always used as a global singleton and can be retrieved with + * `tf.env()`. + */ + /** @doc {heading: 'Environment'} */ + var Environment = /** @class */ (function () { + // tslint:disable-next-line: no-any + function Environment(global) { + this.global = global; + this.flags = {}; + this.flagRegistry = {}; + this.urlFlags = {}; + this.populateURLFlags(); + } + Environment.prototype.setPlatform = function (platformName, platform) { + if (this.platform != null) { + console.warn("Platform " + this.platformName + " has already been set. " + + ("Overwriting the platform with " + platform + ".")); + } + this.platformName = platformName; + this.platform = platform; + }; + Environment.prototype.registerFlag = function (flagName, evaluationFn, setHook) { + this.flagRegistry[flagName] = { evaluationFn: evaluationFn, setHook: setHook }; + // Override the flag value from the URL. This has to happen here because the + // environment is initialized before flags get registered. + if (this.urlFlags[flagName] != null) { + var flagValue = this.urlFlags[flagName]; + console.warn("Setting feature override from URL " + flagName + ": " + flagValue + "."); + this.set(flagName, flagValue); + } + }; + Environment.prototype.get = function (flagName) { + if (flagName in this.flags) { + return this.flags[flagName]; + } + this.flags[flagName] = this.evaluateFlag(flagName); + return this.flags[flagName]; + }; + Environment.prototype.getNumber = function (flagName) { + return this.get(flagName); + }; + Environment.prototype.getBool = function (flagName) { + return this.get(flagName); + }; + Environment.prototype.getFlags = function () { + return this.flags; + }; + Object.defineProperty(Environment.prototype, "features", { + // For backwards compatibility. + get: function () { + return this.flags; + }, + enumerable: true, + configurable: true + }); + Environment.prototype.set = function (flagName, value) { + if (this.flagRegistry[flagName] == null) { + throw new Error("Cannot set flag " + flagName + " as it has not been registered."); + } + this.flags[flagName] = value; + if (this.flagRegistry[flagName].setHook != null) { + this.flagRegistry[flagName].setHook(value); + } + }; + Environment.prototype.evaluateFlag = function (flagName) { + if (this.flagRegistry[flagName] == null) { + throw new Error("Cannot evaluate flag '" + flagName + "': no evaluation function found."); + } + return this.flagRegistry[flagName].evaluationFn(); + }; + Environment.prototype.setFlags = function (flags) { + this.flags = Object.assign({}, flags); + }; + Environment.prototype.reset = function () { + this.flags = {}; + this.urlFlags = {}; + this.populateURLFlags(); + }; + Environment.prototype.populateURLFlags = function () { + var _this = this; + if (typeof this.global === 'undefined' || + typeof this.global.location === 'undefined' || + typeof this.global.location.search === 'undefined') { + return; + } + var urlParams = getQueryParams(this.global.location.search); + if (TENSORFLOWJS_FLAGS_PREFIX in urlParams) { + var keyValues = urlParams[TENSORFLOWJS_FLAGS_PREFIX].split(','); + keyValues.forEach(function (keyValue) { + var _a = keyValue.split(':'), key = _a[0], value = _a[1]; + _this.urlFlags[key] = parseValue(key, value); + }); + } + }; + return Environment; + }()); + function getQueryParams(queryString) { + var params = {}; + queryString.replace(/[?&]([^=?&]+)(?:=([^&]*))?/g, function (s) { + var t = []; + for (var _i = 1; _i < arguments.length; _i++) { + t[_i - 1] = arguments[_i]; + } + decodeParam(params, t[0], t[1]); + return t.join('='); + }); + return params; + } + function decodeParam(params, name, value) { + params[decodeURIComponent(name)] = decodeURIComponent(value || ''); + } + function parseValue(flagName, value) { + value = value.toLowerCase(); + if (value === 'true' || value === 'false') { + return value === 'true'; + } + else if ("" + +value === value) { + return +value; + } + throw new Error("Could not parse value flag value " + value + " for flag " + flagName + "."); + } + /** + * Returns the current environment (a global singleton). + * + * The environment object contains the evaluated feature values as well as the + * active platform. + */ + /** @doc {heading: 'Environment'} */ + function env() { + return exports.ENV; + } + exports.ENV = null; + function setEnvironmentGlobal(environment) { + exports.ENV = environment; + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var kernelRegistry = new Map(); + /** + * Returns the kernel function (code) associated with the provided names. + * + * @param kernelName The official name of the kernel. + * @param backendName The official name of the backend. + */ + function getKernel(kernelName, backendName) { + var key = makeKey(kernelName, backendName); + return kernelRegistry.get(key); + } + function getKernelsForBackend(backendName) { + var it = kernelRegistry.entries(); + var result = []; + while (true) { + var _a = it.next(), done = _a.done, value = _a.value; + if (done) { + break; + } + var key = value[0], config = value[1]; + var backend = key.split('_')[0]; + if (backend === backendName) { + result.push(config); + } + } + return result; + } + /** + * Registers the function (forward pass) for the kernel in a global registry. + * + * @param config A config object with the following properties: + * - `kernelName` The official name of the kernel. + * - `backendName` The official name of the backend. + * - `kernelFunc` The function to run during the forward pass of the kernel. + * - `setupFunc` Optional. Gets called once, after the backend initializes. + * - `disposeFunc` Optional. Gets called once, right before the backend is + * disposed. + */ + function registerKernel(config) { + var kernelName = config.kernelName, backendName = config.backendName; + var key = makeKey(kernelName, backendName); + if (kernelRegistry.has(key)) { + throw new Error("The kernel '" + kernelName + "' for backend " + + ("'" + backendName + "' is already registered")); + } + kernelRegistry.set(key, config); + } + /** + * Removes the kernel function from the registry. + * + * @param kernelName The official name of the kernel. + * @param backendName The official name of the backend. + * + */ + function unregisterKernel(kernelName, backendName) { + var key = makeKey(kernelName, backendName); + if (!kernelRegistry.has(key)) { + throw new Error("The kernel '" + kernelName + "' for backend " + + ("'" + backendName + "' is not registered")); + } + kernelRegistry.delete(key); + } + function makeKey(kernelName, backendName) { + return backendName + "_" + kernelName; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Shuffles the array in-place using Fisher-Yates algorithm. + * + * ```js + * const a = [1, 2, 3, 4, 5]; + * tf.util.shuffle(a); + * console.log(a); + * ``` + * + * @param array The array to shuffle in-place. + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + // tslint:disable-next-line:no-any + function shuffle(array) { + var counter = array.length; + var temp = 0; + var index = 0; + // While there are elements in the array + while (counter > 0) { + // Pick a random index + index = (Math.random() * counter) | 0; + // Decrease counter by 1 + counter--; + // And swap the last element with it + temp = array[counter]; + array[counter] = array[index]; + array[index] = temp; + } + } + /** Clamps a value to a specified range. */ + function clamp(min, x, max) { + return Math.max(min, Math.min(x, max)); + } + function nearestLargerEven(val) { + return val % 2 === 0 ? val : val + 1; + } + function sum(arr) { + var sum = 0; + for (var i = 0; i < arr.length; i++) { + sum += arr[i]; + } + return sum; + } + /** + * Returns a sample from a uniform [a, b) distribution. + * + * @param a The minimum support (inclusive). + * @param b The maximum support (exclusive). + * @return A pseudorandom number on the half-open interval [a,b). + */ + function randUniform(a, b) { + var r = Math.random(); + return (b * r) + (1 - r) * a; + } + /** Returns the squared Euclidean distance between two vectors. */ + function distSquared(a, b) { + var result = 0; + for (var i = 0; i < a.length; i++) { + var diff = Number(a[i]) - Number(b[i]); + result += diff * diff; + } + return result; + } + /** + * Asserts that the expression is true. Otherwise throws an error with the + * provided message. + * + * ```js + * const x = 2; + * tf.util.assert(x === 2, 'x is not 2'); + * ``` + * + * @param expr The expression to assert (as a boolean). + * @param msg A function that returns the message to report when throwing an + * error. We use a function for performance reasons. + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function assert(expr, msg) { + if (!expr) { + throw new Error(typeof msg === 'string' ? msg : msg()); + } + } + function assertShapesMatch(shapeA, shapeB, errorMessagePrefix) { + if (errorMessagePrefix === void 0) { errorMessagePrefix = ''; } + assert(arraysEqual(shapeA, shapeB), function () { return errorMessagePrefix + (" Shapes " + shapeA + " and " + shapeB + " must match"); }); + } + function assertNonNull(a) { + assert(a != null, function () { return "The input to the tensor constructor must be a non-null value."; }); + } + // NOTE: We explicitly type out what T extends instead of any so that + // util.flatten on a nested array of number doesn't try to infer T as a + // number[][], causing us to explicitly type util.flatten(). + /** + * Flattens an arbitrarily nested array. + * + * ```js + * const a = [[1, 2], [3, 4], [5, [6, [7]]]]; + * const flat = tf.util.flatten(a); + * console.log(flat); + * ``` + * + * @param arr The nested array to flatten. + * @param result The destination array which holds the elements. + * @param skipTypedArray If true, avoids flattening the typed arrays. Defaults + * to false. + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function flatten(arr, result, skipTypedArray) { + if (result === void 0) { result = []; } + if (skipTypedArray === void 0) { skipTypedArray = false; } + if (result == null) { + result = []; + } + if (Array.isArray(arr) || isTypedArray(arr) && !skipTypedArray) { + for (var i = 0; i < arr.length; ++i) { + flatten(arr[i], result, skipTypedArray); + } + } + else { + result.push(arr); + } + return result; + } + /** + * Returns the size (number of elements) of the tensor given its shape. + * + * ```js + * const shape = [3, 4, 2]; + * const size = tf.util.sizeFromShape(shape); + * console.log(size); + * ``` + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function sizeFromShape(shape) { + if (shape.length === 0) { + // Scalar. + return 1; + } + var size = shape[0]; + for (var i = 1; i < shape.length; i++) { + size *= shape[i]; + } + return size; + } + function isScalarShape(shape) { + return shape.length === 0; + } + function arraysEqual(n1, n2) { + if (n1 === n2) { + return true; + } + if (n1 == null || n2 == null) { + return false; + } + if (n1.length !== n2.length) { + return false; + } + for (var i = 0; i < n1.length; i++) { + if (n1[i] !== n2[i]) { + return false; + } + } + return true; + } + function isInt(a) { + return a % 1 === 0; + } + function tanh(x) { + // tslint:disable-next-line:no-any + if (Math.tanh != null) { + // tslint:disable-next-line:no-any + return Math.tanh(x); + } + if (x === Infinity) { + return 1; + } + else if (x === -Infinity) { + return -1; + } + else { + var e2x = Math.exp(2 * x); + return (e2x - 1) / (e2x + 1); + } + } + function sizeToSquarishShape(size) { + var width = Math.ceil(Math.sqrt(size)); + return [width, Math.ceil(size / width)]; + } + /** + * Creates a new array with randomized indicies to a given quantity. + * + * ```js + * const randomTen = tf.util.createShuffledIndices(10); + * console.log(randomTen); + * ``` + * + * @param number Quantity of how many shuffled indicies to create. + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function createShuffledIndices(n) { + var shuffledIndices = new Uint32Array(n); + for (var i = 0; i < n; ++i) { + shuffledIndices[i] = i; + } + shuffle(shuffledIndices); + return shuffledIndices; + } + function rightPad(a, size) { + if (size <= a.length) { + return a; + } + return a + ' '.repeat(size - a.length); + } + function repeatedTry(checkFn, delayFn, maxCounter) { + if (delayFn === void 0) { delayFn = function (counter) { return 0; }; } + return new Promise(function (resolve, reject) { + var tryCount = 0; + var tryFn = function () { + if (checkFn()) { + resolve(); + return; + } + tryCount++; + var nextBackoff = delayFn(tryCount); + if (maxCounter != null && tryCount >= maxCounter) { + reject(); + return; + } + setTimeout(tryFn, nextBackoff); + }; + tryFn(); + }); + } + /** + * Given the full size of the array and a shape that may contain -1 as the + * implicit dimension, returns the inferred shape where -1 is replaced. + * E.g. For shape=[2, -1, 3] and size=24, it will return [2, 4, 3]. + * + * @param shape The shape, which may contain -1 in some dimension. + * @param size The full size (number of elements) of the array. + * @return The inferred shape where -1 is replaced with the inferred size. + */ + function inferFromImplicitShape(shape, size) { + var shapeProd = 1; + var implicitIdx = -1; + for (var i = 0; i < shape.length; ++i) { + if (shape[i] >= 0) { + shapeProd *= shape[i]; + } + else if (shape[i] === -1) { + if (implicitIdx !== -1) { + throw Error("Shapes can only have 1 implicit size. " + + ("Found -1 at dim " + implicitIdx + " and dim " + i)); + } + implicitIdx = i; + } + else if (shape[i] < 0) { + throw Error("Shapes can not be < 0. Found " + shape[i] + " at dim " + i); + } + } + if (implicitIdx === -1) { + if (size > 0 && size !== shapeProd) { + throw Error("Size(" + size + ") must match the product of shape " + shape); + } + return shape; + } + if (shapeProd === 0) { + throw Error("Cannot infer the missing size in [" + shape + "] when " + + "there are 0 elements"); + } + if (size % shapeProd !== 0) { + throw Error("The implicit shape can't be a fractional number. " + + ("Got " + size + " / " + shapeProd)); + } + var newShape = shape.slice(); + newShape[implicitIdx] = size / shapeProd; + return newShape; + } + function parseAxisParam(axis, shape) { + var rank = shape.length; + // Normalize input + axis = axis == null ? shape.map(function (s, i) { return i; }) : [].concat(axis); + // Check for valid range + assert(axis.every(function (ax) { return ax >= -rank && ax < rank; }), function () { + return "All values in axis param must be in range [-" + rank + ", " + rank + ") but " + + ("got axis " + axis); + }); + // Check for only integers + assert(axis.every(function (ax) { return isInt(ax); }), function () { return "All values in axis param must be integers but " + + ("got axis " + axis); }); + // Handle negative axis. + return axis.map(function (a) { return a < 0 ? rank + a : a; }); + } + /** Reduces the shape by removing all dimensions of shape 1. */ + function squeezeShape(shape, axis) { + var newShape = []; + var keptDims = []; + var isEmptyArray = axis != null && Array.isArray(axis) && axis.length === 0; + var axes = (axis == null || isEmptyArray) ? + null : + parseAxisParam(axis, shape).sort(); + var j = 0; + for (var i = 0; i < shape.length; ++i) { + if (axes != null) { + if (axes[j] === i && shape[i] !== 1) { + throw new Error("Can't squeeze axis " + i + " since its dim '" + shape[i] + "' is not 1"); + } + if ((axes[j] == null || axes[j] > i) && shape[i] === 1) { + newShape.push(shape[i]); + keptDims.push(i); + } + if (axes[j] <= i) { + j++; + } + } + if (shape[i] !== 1) { + newShape.push(shape[i]); + keptDims.push(i); + } + } + return { newShape: newShape, keptDims: keptDims }; + } + function getTypedArrayFromDType(dtype, size) { + var values = null; + if (dtype == null || dtype === 'float32') { + values = new Float32Array(size); + } + else if (dtype === 'int32') { + values = new Int32Array(size); + } + else if (dtype === 'bool') { + values = new Uint8Array(size); + } + else { + throw new Error("Unknown data type " + dtype); + } + return values; + } + function getArrayFromDType(dtype, size) { + var values = null; + if (dtype == null || dtype === 'float32') { + values = new Float32Array(size); + } + else if (dtype === 'int32') { + values = new Int32Array(size); + } + else if (dtype === 'bool') { + values = new Uint8Array(size); + } + else if (dtype === 'string') { + values = new Array(size); + } + else { + throw new Error("Unknown data type " + dtype); + } + return values; + } + function checkConversionForErrors(vals, dtype) { + for (var i = 0; i < vals.length; i++) { + var num = vals[i]; + if (isNaN(num) || !isFinite(num)) { + throw Error("A tensor of type " + dtype + " being uploaded contains " + num + "."); + } + } + } + /** Returns true if the dtype is valid. */ + function isValidDtype(dtype) { + return dtype === 'bool' || dtype === 'complex64' || dtype === 'float32' || + dtype === 'int32' || dtype === 'string'; + } + /** + * Returns true if the new type can't encode the old type without loss of + * precision. + */ + function hasEncodingLoss(oldType, newType) { + if (newType === 'complex64') { + return false; + } + if (newType === 'float32' && oldType !== 'complex64') { + return false; + } + if (newType === 'int32' && oldType !== 'float32' && oldType !== 'complex64') { + return false; + } + if (newType === 'bool' && oldType === 'bool') { + return false; + } + return true; + } + function isTypedArray(a) { + return a instanceof Float32Array || a instanceof Int32Array || + a instanceof Uint8Array; + } + function bytesPerElement(dtype) { + if (dtype === 'float32' || dtype === 'int32') { + return 4; + } + else if (dtype === 'complex64') { + return 8; + } + else if (dtype === 'bool') { + return 1; + } + else { + throw new Error("Unknown dtype " + dtype); + } + } + /** + * Returns the approximate number of bytes allocated in the string array - 2 + * bytes per character. Computing the exact bytes for a native string in JS is + * not possible since it depends on the encoding of the html page that serves + * the website. + */ + function bytesFromStringArray(arr) { + if (arr == null) { + return 0; + } + var bytes = 0; + arr.forEach(function (x) { return bytes += x.length; }); + return bytes; + } + /** Returns true if the value is a string. */ + function isString(value) { + return typeof value === 'string' || value instanceof String; + } + function isBoolean(value) { + return typeof value === 'boolean'; + } + function isNumber(value) { + return typeof value === 'number'; + } + function inferDtype(values) { + if (Array.isArray(values)) { + return inferDtype(values[0]); + } + if (values instanceof Float32Array) { + return 'float32'; + } + else if (values instanceof Int32Array || values instanceof Uint8Array) { + return 'int32'; + } + else if (isNumber(values)) { + return 'float32'; + } + else if (isString(values)) { + return 'string'; + } + else if (isBoolean(values)) { + return 'bool'; + } + return 'float32'; + } + function isFunction(f) { + return !!(f && f.constructor && f.call && f.apply); + } + function nearestDivisor(size, start) { + for (var i = start; i < size; ++i) { + if (size % i === 0) { + return i; + } + } + return size; + } + function computeStrides(shape) { + var rank = shape.length; + if (rank < 2) { + return []; + } + // Last dimension has implicit stride of 1, thus having D-1 (instead of D) + // strides. + var strides = new Array(rank - 1); + strides[rank - 2] = shape[rank - 1]; + for (var i = rank - 3; i >= 0; --i) { + strides[i] = strides[i + 1] * shape[i + 1]; + } + return strides; + } + function toTypedArray(a, dtype, debugMode) { + if (dtype === 'string') { + throw new Error('Cannot convert a string[] to a TypedArray'); + } + if (Array.isArray(a)) { + a = flatten(a); + } + if (debugMode) { + checkConversionForErrors(a, dtype); + } + if (noConversionNeeded(a, dtype)) { + return a; + } + if (dtype == null || dtype === 'float32' || dtype === 'complex64') { + return new Float32Array(a); + } + else if (dtype === 'int32') { + return new Int32Array(a); + } + else if (dtype === 'bool') { + var bool = new Uint8Array(a.length); + for (var i = 0; i < bool.length; ++i) { + if (Math.round(a[i]) !== 0) { + bool[i] = 1; + } + } + return bool; + } + else { + throw new Error("Unknown data type " + dtype); + } + } + function createNestedArray(offset, shape, a) { + var ret = new Array(); + if (shape.length === 1) { + var d = shape[0]; + for (var i = 0; i < d; i++) { + ret[i] = a[offset + i]; + } + } + else { + var d = shape[0]; + var rest = shape.slice(1); + var len = rest.reduce(function (acc, c) { return acc * c; }); + for (var i = 0; i < d; i++) { + ret[i] = createNestedArray(offset + i * len, rest, a); + } + } + return ret; + } + // Provide a nested array of TypedArray in given shape. + function toNestedArray(shape, a) { + if (shape.length === 0) { + // Scalar type should return a single number. + return a[0]; + } + var size = shape.reduce(function (acc, c) { return acc * c; }); + if (size === 0) { + // A tensor with shape zero should be turned into empty list. + return []; + } + if (size !== a.length) { + throw new Error("[" + shape + "] does not match the input size."); + } + return createNestedArray(0, shape, a); + } + function noConversionNeeded(a, dtype) { + return (a instanceof Float32Array && dtype === 'float32') || + (a instanceof Int32Array && dtype === 'int32') || + (a instanceof Uint8Array && dtype === 'bool'); + } + function makeOnesTypedArray(size, dtype) { + var array = makeZerosTypedArray(size, dtype); + for (var i = 0; i < array.length; i++) { + array[i] = 1; + } + return array; + } + function makeZerosTypedArray(size, dtype) { + if (dtype == null || dtype === 'float32' || dtype === 'complex64') { + return new Float32Array(size); + } + else if (dtype === 'int32') { + return new Int32Array(size); + } + else if (dtype === 'bool') { + return new Uint8Array(size); + } + else { + throw new Error("Unknown data type " + dtype); + } + } + /** + * Returns the current high-resolution time in milliseconds relative to an + * arbitrary time in the past. It works across different platforms (node.js, + * browsers). + * + * ```js + * console.log(tf.util.now()); + * ``` + */ + /** @doc {heading: 'Util', namespace: 'util'} */ + function now() { + return env().platform.now(); + } + function assertNonNegativeIntegerDimensions(shape) { + shape.forEach(function (dimSize) { + assert(Number.isInteger(dimSize) && dimSize >= 0, function () { + return "Tensor must have a shape comprised of positive integers but got " + + ("shape [" + shape + "]."); + }); + }); + } + /** + * Returns a platform-specific implementation of + * [`fetch`](https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API). + * + * If `fetch` is defined on the global object (`window`, `process`, etc.), + * `tf.util.fetch` returns that function. + * + * If not, `tf.util.fetch` returns a platform-specific solution. + * + * ```js + * const resource = await tf.util.fetch('https://unpkg.com/@tensorflow/tfjs'); + * // handle response + * ``` + */ + /** @doc {heading: 'Util'} */ + function fetch$1(path, requestInits) { + return env().platform.fetch(path, requestInits); + } + /** + * Encodes the provided string into bytes using the provided encoding scheme. + * + * @param s The string to encode. + * @param encoding The encoding scheme. Defaults to utf-8. + * + */ + /** @doc {heading: 'Util'} */ + function encodeString(s, encoding) { + if (encoding === void 0) { encoding = 'utf-8'; } + encoding = encoding || 'utf-8'; + return env().platform.encode(s, encoding); + } + /** + * Decodes the provided bytes into a string using the provided encoding scheme. + * @param bytes The bytes to decode. + * + * @param encoding The encoding scheme. Defaults to utf-8. + */ + /** @doc {heading: 'Util'} */ + function decodeString(bytes, encoding) { + if (encoding === void 0) { encoding = 'utf-8'; } + encoding = encoding || 'utf-8'; + return env().platform.decode(bytes, encoding); + } + + var util = /*#__PURE__*/Object.freeze({ + shuffle: shuffle, + clamp: clamp, + nearestLargerEven: nearestLargerEven, + sum: sum, + randUniform: randUniform, + distSquared: distSquared, + assert: assert, + assertShapesMatch: assertShapesMatch, + assertNonNull: assertNonNull, + flatten: flatten, + sizeFromShape: sizeFromShape, + isScalarShape: isScalarShape, + arraysEqual: arraysEqual, + isInt: isInt, + tanh: tanh, + sizeToSquarishShape: sizeToSquarishShape, + createShuffledIndices: createShuffledIndices, + rightPad: rightPad, + repeatedTry: repeatedTry, + inferFromImplicitShape: inferFromImplicitShape, + parseAxisParam: parseAxisParam, + squeezeShape: squeezeShape, + getTypedArrayFromDType: getTypedArrayFromDType, + getArrayFromDType: getArrayFromDType, + checkConversionForErrors: checkConversionForErrors, + isValidDtype: isValidDtype, + hasEncodingLoss: hasEncodingLoss, + isTypedArray: isTypedArray, + bytesPerElement: bytesPerElement, + bytesFromStringArray: bytesFromStringArray, + isString: isString, + isBoolean: isBoolean, + isNumber: isNumber, + inferDtype: inferDtype, + isFunction: isFunction, + nearestDivisor: nearestDivisor, + computeStrides: computeStrides, + toTypedArray: toTypedArray, + toNestedArray: toNestedArray, + makeOnesTypedArray: makeOnesTypedArray, + makeZerosTypedArray: makeZerosTypedArray, + now: now, + assertNonNegativeIntegerDimensions: assertNonNegativeIntegerDimensions, + fetch: fetch$1, + encodeString: encodeString, + decodeString: decodeString + }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var Profiler = /** @class */ (function () { + function Profiler(backendTimer, logger) { + this.backendTimer = backendTimer; + this.logger = logger; + if (logger == null) { + this.logger = new Logger(); + } + } + Profiler.prototype.profileKernel = function (kernelName, inputs, f) { + var _this = this; + var outputs; + var holdResultWrapperFn = function () { + outputs = f(); + }; + var timer = this.backendTimer.time(holdResultWrapperFn); + outputs.forEach(function (r) { + // Dangling promise here because we don't want to propagate up + // asynchronicity. + r.data().then(function (vals) { + checkComputationForErrors(vals, r.dtype, kernelName); + timer.then(function (timing) { + var extraInfo = ''; + if (timing.getExtraProfileInfo != null) { + extraInfo = timing.getExtraProfileInfo(); + } + _this.logger.logKernelProfile(kernelName, r, vals, timing.kernelMs, inputs, extraInfo); + }); + }); + }); + return outputs; + }; + return Profiler; + }()); + function checkComputationForErrors(vals, dtype, kernelName) { + if (dtype !== 'float32') { + // Only floating point computations will generate NaN values + return false; + } + for (var i = 0; i < vals.length; i++) { + var num = vals[i]; + if (isNaN(num) || !isFinite(num)) { + // Throwing custom exception so behavior is testable. + console.warn("Found " + num + " in the result of '" + kernelName + "'"); + return true; + } + } + return false; + } + var Logger = /** @class */ (function () { + function Logger() { + } + Logger.prototype.logKernelProfile = function (name, result, vals, timeMs, inputs, extraInfo) { + var time = rightPad(timeMs + "ms", 9); + var paddedName = rightPad(name, 25); + var rank = result.rank; + var size = result.size; + var shape = rightPad(result.shape.toString(), 14); + var inputShapesDescription = ''; + for (var name_1 in inputs) { + var inputShape = inputs[name_1].shape; + var inputRank = inputShape.length; + inputShapesDescription += + name_1 + ": " + inputRank + "D " + (inputRank > 0 ? inputShape : '') + " "; + } + console.log("%c" + paddedName + "\t%c" + time + "\t%c" + rank + "D " + shape + "\t%c" + size + "\t%c" + inputShapesDescription + "\t%c" + extraInfo, 'font-weight:bold', 'color:red', 'color:blue', 'color: orange', 'color: green', 'color: steelblue'); + }; + return Logger; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes a list of TapeNodes that connect x to y, filtering everything else + * out and preserving the order of the original tape elements. + * + * @param tape The tape elements to filter. + * @param xs The input Tensors. + * @param y The output Tensor. + */ + function getFilteredNodesXToY(tape, xs, y) { + // Forward pass to compute all the nodes and Tensors that are transitively a + // function of x. + var tensorsFromX = {}; + var nodesFromX = {}; + for (var i = 0; i < xs.length; i++) { + tensorsFromX[xs[i].id] = true; + } + for (var i = 0; i < tape.length; i++) { + var node = tape[i]; + var nodeInputs = node.inputs; + for (var inputName in nodeInputs) { + var input = nodeInputs[inputName]; + var anyInputFromX = false; + for (var j = 0; j < xs.length; j++) { + if (tensorsFromX[input.id]) { + node.outputs.forEach(function (output) { return tensorsFromX[output.id] = true; }); + anyInputFromX = true; + nodesFromX[node.id] = true; + break; + } + } + if (anyInputFromX) { + break; + } + } + } + // Backward pass to find all of the nodes and Tensors that lead to y. + var tensorsLeadToY = {}; + tensorsLeadToY[y.id] = true; + var nodesToY = {}; + for (var i = tape.length - 1; i >= 0; i--) { + var node = tape[i]; + var nodeInputs = node.inputs; + // If any of the outputs lead to y, mark all of the inputs as leading to y. + for (var j = 0; j < node.outputs.length; j++) { + if (tensorsLeadToY[node.outputs[j].id]) { + for (var inputName in nodeInputs) { + tensorsLeadToY[nodeInputs[inputName].id] = true; + nodesToY[node.id] = true; + } + break; + } + } + } + // Return the paths that come from x and lead to y. + var filteredTape = []; + for (var i = 0; i < tape.length; i++) { + var node = tape[i]; + if (nodesFromX[node.id] && nodesToY[node.id]) { + // Prune the inputs from the node that aren't a function of x. + var prunedInputs = {}; + for (var inputName in node.inputs) { + var nodeInput = node.inputs[inputName]; + if (tensorsFromX[nodeInput.id]) { + prunedInputs[inputName] = nodeInput; + } + } + // Copy the node and overwrite inputsAndArgs to the pruned version. + var prunedNode = Object.assign({}, node); + prunedNode.inputs = prunedInputs; + prunedNode.outputs = node.outputs; + filteredTape.push(prunedNode); + } + } + return filteredTape; + } + /** + * Backpropagate gradients through the filtered TapeNodes. + * + * @param tensorAccumulatedGradientMap A map of Tensor to its gradient. This map + * is mutated by this method. + * @param filteredTape The filtered TapeNodes to backprop through. + */ + function backpropagateGradients(tensorAccumulatedGradientMap, filteredTape, tidy) { + var _loop_1 = function (i) { + var node = filteredTape[i]; + var dys = []; + node.outputs.forEach(function (o) { + var gradTensor = tensorAccumulatedGradientMap[o.id]; + if (gradTensor != null) { + dys.push(gradTensor); + } + else { + // This particular output is not in the back-propagation subgraph, so it + // does not affect the final output, thus we put null for its dy. + dys.push(null); + } + }); + if (node.gradient == null) { + throw new Error("Cannot compute gradient: gradient function not found " + + ("for " + node.name + ".")); + } + // Backprop dy through this node and accumulate gradients over the inputs. + var inputGradients = node.gradient(dys); + var _loop_2 = function (inputName) { + if (!(inputName in inputGradients)) { + throw new Error("Cannot backprop through input " + inputName + ". " + + ("Available gradients found: " + Object.keys(inputGradients) + ".")); + } + // Call the gradient function. + var dx = tidy(function () { return inputGradients[inputName](); }); + if (dx.dtype !== 'float32') { + throw new Error("Error in gradient for op " + node.name + ". The gradient of input " + + (inputName + " must have 'float32' dtype, but has '" + dx.dtype + "'")); + } + var x = node.inputs[inputName]; + if (!arraysEqual(dx.shape, x.shape)) { + throw new Error("Error in gradient for op " + node.name + ". The gradient of input " + + ("'" + inputName + "' has shape '" + dx.shape + "', which does not match ") + + ("the shape of the input '" + x.shape + "'")); + } + if (tensorAccumulatedGradientMap[x.id] == null) { + tensorAccumulatedGradientMap[x.id] = dx; + } + else { + var curGradient = tensorAccumulatedGradientMap[x.id]; + tensorAccumulatedGradientMap[x.id] = curGradient.add(dx); + curGradient.dispose(); + } + }; + for (var inputName in node.inputs) { + _loop_2(inputName); + } + }; + // Walk the tape backward and keep a map of Tensor to its gradient. + for (var i = filteredTape.length - 1; i >= 0; i--) { + _loop_1(i); + } + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Maximum number of values before we decide to show ellipsis. + var FORMAT_LIMIT_NUM_VALS = 20; + // Number of first and last values to show when displaying a, b,...,y, z. + var FORMAT_NUM_FIRST_LAST_VALS = 3; + // Number of significant digits to show. + var FORMAT_NUM_SIG_DIGITS = 7; + function tensorToString(vals, shape, dtype, verbose) { + var strides = computeStrides(shape); + var padPerCol = computeMaxSizePerColumn(vals, shape, dtype, strides); + var rank = shape.length; + var valsLines = subTensorToString(vals, shape, dtype, strides, padPerCol); + var lines = ['Tensor']; + if (verbose) { + lines.push(" dtype: " + dtype); + lines.push(" rank: " + rank); + lines.push(" shape: [" + shape + "]"); + lines.push(" values:"); + } + lines.push(valsLines.map(function (l) { return ' ' + l; }).join('\n')); + return lines.join('\n'); + } + function computeMaxSizePerColumn(vals, shape, dtype, strides) { + var n = sizeFromShape(shape); + var numCols = strides[strides.length - 1]; + var padPerCol = new Array(numCols).fill(0); + var rank = shape.length; + var valuesOrTuples = dtype === 'complex64' ? createComplexTuples(vals) : vals; + if (rank > 1) { + for (var row = 0; row < n / numCols; row++) { + var offset = row * numCols; + for (var j = 0; j < numCols; j++) { + padPerCol[j] = Math.max(padPerCol[j], valToString(valuesOrTuples[offset + j], 0, dtype).length); + } + } + } + return padPerCol; + } + function valToString(val, pad, dtype) { + var valStr; + if (Array.isArray(val)) { + valStr = parseFloat(val[0].toFixed(FORMAT_NUM_SIG_DIGITS)) + " + " + + (parseFloat(val[1].toFixed(FORMAT_NUM_SIG_DIGITS)) + "j"); + } + else if (isString(val)) { + valStr = "'" + val + "'"; + } + else if (dtype === 'bool') { + valStr = boolNumToString(val); + } + else { + valStr = parseFloat(val.toFixed(FORMAT_NUM_SIG_DIGITS)).toString(); + } + return rightPad(valStr, pad); + } + function boolNumToString(v) { + return v === 0 ? 'false' : 'true'; + } + function subTensorToString(vals, shape, dtype, strides, padPerCol, isLast) { + if (isLast === void 0) { isLast = true; } + var storagePerElement = dtype === 'complex64' ? 2 : 1; + var size = shape[0]; + var rank = shape.length; + if (rank === 0) { + if (dtype === 'complex64') { + var complexTuple = createComplexTuples(vals); + return [valToString(complexTuple[0], 0, dtype)]; + } + if (dtype === 'bool') { + return [boolNumToString(vals[0])]; + } + return [vals[0].toString()]; + } + if (rank === 1) { + if (size > FORMAT_LIMIT_NUM_VALS) { + var firstValsSize = FORMAT_NUM_FIRST_LAST_VALS * storagePerElement; + var firstVals = Array.from(vals.slice(0, firstValsSize)); + var lastVals = Array.from(vals.slice((size - FORMAT_NUM_FIRST_LAST_VALS) * storagePerElement, size * storagePerElement)); + if (dtype === 'complex64') { + firstVals = createComplexTuples(firstVals); + lastVals = createComplexTuples(lastVals); + } + return [ + '[' + + firstVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) + .join(', ') + + ', ..., ' + + lastVals + .map(function (x, i) { return valToString(x, padPerCol[size - FORMAT_NUM_FIRST_LAST_VALS + i], dtype); }) + .join(', ') + + ']' + ]; + } + var displayVals = dtype === 'complex64' ? createComplexTuples(vals) : + Array.from(vals); + return [ + '[' + + displayVals.map(function (x, i) { return valToString(x, padPerCol[i], dtype); }) + .join(', ') + + ']' + ]; + } + // The array is rank 2 or more. + var subshape = shape.slice(1); + var substrides = strides.slice(1); + var stride = strides[0] * storagePerElement; + var lines = []; + if (size > FORMAT_LIMIT_NUM_VALS) { + for (var i = 0; i < FORMAT_NUM_FIRST_LAST_VALS; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, false /* isLast */)); + } + lines.push('...'); + for (var i = size - FORMAT_NUM_FIRST_LAST_VALS; i < size; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); + } + } + else { + for (var i = 0; i < size; i++) { + var start = i * stride; + var end = start + stride; + lines.push.apply(lines, subTensorToString(vals.slice(start, end), subshape, dtype, substrides, padPerCol, i === size - 1 /* isLast */)); + } + } + var sep = rank === 2 ? ',' : ''; + lines[0] = '[' + lines[0] + sep; + for (var i = 1; i < lines.length - 1; i++) { + lines[i] = ' ' + lines[i] + sep; + } + var newLineSep = ',\n'; + for (var i = 2; i < rank; i++) { + newLineSep += '\n'; + } + lines[lines.length - 1] = + ' ' + lines[lines.length - 1] + ']' + (isLast ? '' : newLineSep); + return lines; + } + function createComplexTuples(vals) { + var complexTuples = []; + for (var i = 0; i < vals.length; i += 2) { + complexTuples.push([vals[i], vals[i + 1]]); + } + return complexTuples; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * A mutable object, similar to `tf.Tensor`, that allows users to set values + * at locations before converting to an immutable `tf.Tensor`. + * + * See `tf.buffer` for creating a tensor buffer. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + var TensorBuffer = /** @class */ (function () { + function TensorBuffer(shape, dtype, values) { + var _this = this; + this.dtype = dtype; + this.shape = shape.slice(); + this.size = sizeFromShape(shape); + if (values != null) { + var n_1 = values.length; + assert(n_1 === this.size, function () { return "Length of values '" + n_1 + "' does not match the size " + + ("inferred by the shape '" + _this.size + "'."); }); + } + if (dtype === 'complex64') { + throw new Error("complex64 dtype TensorBuffers are not supported. Please create " + + "a TensorBuffer for the real and imaginary parts separately and " + + "call tf.complex(real, imag)."); + } + this.values = values || getArrayFromDType(dtype, this.size); + this.strides = computeStrides(shape); + } + /** + * Sets a value in the buffer at a given location. + * + * @param value The value to set. + * @param locs The location indices. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + TensorBuffer.prototype.set = function (value) { + var _this = this; + var locs = []; + for (var _i = 1; _i < arguments.length; _i++) { + locs[_i - 1] = arguments[_i]; + } + if (locs.length === 0) { + locs = [0]; + } + assert(locs.length === this.rank, function () { return "The number of provided coordinates (" + locs.length + ") must " + + ("match the rank (" + _this.rank + ")"); }); + var index = this.locToIndex(locs); + this.values[index] = value; + }; + /** + * Returns the value in the buffer at the provided location. + * + * @param locs The location indices. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + TensorBuffer.prototype.get = function () { + var locs = []; + for (var _i = 0; _i < arguments.length; _i++) { + locs[_i] = arguments[_i]; + } + if (locs.length === 0) { + locs = [0]; + } + var i = 0; + for (var _a = 0, locs_1 = locs; _a < locs_1.length; _a++) { + var loc = locs_1[_a]; + if (loc < 0 || loc >= this.shape[i]) { + var msg = "Requested out of range element at " + locs + ". " + + (" Buffer shape=" + this.shape); + throw new Error(msg); + } + i++; + } + var index = locs[locs.length - 1]; + for (var i_1 = 0; i_1 < locs.length - 1; ++i_1) { + index += this.strides[i_1] * locs[i_1]; + } + return this.values[index]; + }; + TensorBuffer.prototype.locToIndex = function (locs) { + if (this.rank === 0) { + return 0; + } + else if (this.rank === 1) { + return locs[0]; + } + var index = locs[locs.length - 1]; + for (var i = 0; i < locs.length - 1; ++i) { + index += this.strides[i] * locs[i]; + } + return index; + }; + TensorBuffer.prototype.indexToLoc = function (index) { + if (this.rank === 0) { + return []; + } + else if (this.rank === 1) { + return [index]; + } + var locs = new Array(this.shape.length); + for (var i = 0; i < locs.length - 1; ++i) { + locs[i] = Math.floor(index / this.strides[i]); + index -= locs[i] * this.strides[i]; + } + locs[locs.length - 1] = index; + return locs; + }; + Object.defineProperty(TensorBuffer.prototype, "rank", { + get: function () { + return this.shape.length; + }, + enumerable: true, + configurable: true + }); + /** + * Creates an immutable `tf.Tensor` object from the buffer. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + TensorBuffer.prototype.toTensor = function () { + return trackerFn().makeTensor(this.values, this.shape, this.dtype); + }; + return TensorBuffer; + }()); + // For tracking tensor creation and disposal. + var trackerFn = null; + // Used by chaining methods to call into ops. + var opHandler = null; + // Used to warn about deprecated methods. + var deprecationWarningFn = null; + /** + * An external consumer can register itself as the tensor tracker. This way + * the Tensor class can notify the tracker for every tensor created and + * disposed. + */ + function setTensorTracker(fn) { + trackerFn = fn; + } + /** + * An external consumer can register itself as the op handler. This way the + * Tensor class can have chaining methods that call into ops via the op + * handler. + */ + function setOpHandler(handler) { + opHandler = handler; + } + /** + * Sets the deprecation warning function to be used by this file. This way the + * Tensor class can be a leaf but still use the environment. + */ + function setDeprecationWarningFn(fn) { + deprecationWarningFn = fn; + } + /** + * A `tf.Tensor` object represents an immutable, multidimensional array of + * numbers that has a shape and a data type. + * + * See `tf.tensor` for details on how to create a `tf.Tensor`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + var Tensor = /** @class */ (function () { + function Tensor(shape, dtype, dataId, id) { + /** Whether this tensor has been globally kept. */ + this.kept = false; + this.isDisposedInternal = false; + this.shape = shape.slice(); + this.dtype = dtype || 'float32'; + this.size = sizeFromShape(shape); + this.strides = computeStrides(shape); + this.dataId = dataId; + this.id = id; + this.rankType = (this.rank < 5 ? this.rank.toString() : 'higher'); + } + /** Flatten a Tensor to a 1D array. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.flatten = function () { + this.throwIfDisposed(); + return this.as1D(); + }; + /** Converts a size-1 `tf.Tensor` to a `tf.Scalar`. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.asScalar = function () { + this.throwIfDisposed(); + assert(this.size === 1, function () { return 'The array must have only 1 element.'; }); + return this.reshape([]); + }; + /** Converts a `tf.Tensor` to a `tf.Tensor1D`. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as1D = function () { + this.throwIfDisposed(); + return this.reshape([this.size]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor2D`. + * + * @param rows Number of rows in `tf.Tensor2D`. + * @param columns Number of columns in `tf.Tensor2D`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as2D = function (rows, columns) { + this.throwIfDisposed(); + return this.reshape([rows, columns]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor3D`. + * + * @param rows Number of rows in `tf.Tensor3D`. + * @param columns Number of columns in `tf.Tensor3D`. + * @param depth Depth of `tf.Tensor3D`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as3D = function (rows, columns, depth) { + this.throwIfDisposed(); + return this.reshape([rows, columns, depth]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor4D`. + * + * @param rows Number of rows in `tf.Tensor4D`. + * @param columns Number of columns in `tf.Tensor4D`. + * @param depth Depth of `tf.Tensor4D`. + * @param depth2 4th dimension of `tf.Tensor4D`. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as4D = function (rows, columns, depth, depth2) { + this.throwIfDisposed(); + return this.reshape([rows, columns, depth, depth2]); + }; + /** + * Converts a `tf.Tensor` to a `tf.Tensor5D`. + * + * @param rows Number of rows in `tf.Tensor5D`. + * @param columns Number of columns in `tf.Tensor5D`. + * @param depth Depth of `tf.Tensor5D`. + * @param depth2 4th dimension of `tf.Tensor5D`. + * @param depth3 5th dimension of 'tf.Tensor5D' + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.as5D = function (rows, columns, depth, depth2, depth3) { + this.throwIfDisposed(); + return this.reshape([rows, columns, depth, depth2, depth3]); + }; + /** + * Casts a `tf.Tensor` to a specified dtype. + * + * @param dtype Data-type to cast the tensor to. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.asType = function (dtype) { + this.throwIfDisposed(); + return opHandler.cast(this, dtype); + }; + Object.defineProperty(Tensor.prototype, "rank", { + get: function () { + return this.shape.length; + }, + enumerable: true, + configurable: true + }); + /** + * Returns a promise of `tf.TensorBuffer` that holds the underlying data. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.buffer = function () { + return __awaiter(this, void 0, void 0, function () { + var vals; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.data()]; + case 1: + vals = _a.sent(); + return [2 /*return*/, opHandler.buffer(this.shape, this.dtype, vals)]; + } + }); + }); + }; + /** Returns a `tf.TensorBuffer` that holds the underlying data. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.bufferSync = function () { + return opHandler.buffer(this.shape, this.dtype, this.dataSync()); + }; + /** + * Returns the tensor data as a nested array. The transfer of data is done + * asynchronously. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.array = function () { + return __awaiter(this, void 0, void 0, function () { + var vals; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.data()]; + case 1: + vals = _a.sent(); + return [2 /*return*/, toNestedArray(this.shape, vals)]; + } + }); + }); + }; + /** + * Returns the tensor data as a nested array. The transfer of data is done + * synchronously. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.arraySync = function () { + return toNestedArray(this.shape, this.dataSync()); + }; + /** + * Asynchronously downloads the values from the `tf.Tensor`. Returns a + * promise of `TypedArray` that resolves when the computation has finished. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.data = function () { + return __awaiter(this, void 0, void 0, function () { + var data, bytes; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + this.throwIfDisposed(); + data = trackerFn().read(this.dataId); + if (!(this.dtype === 'string')) return [3 /*break*/, 2]; + return [4 /*yield*/, data]; + case 1: + bytes = _a.sent(); + try { + return [2 /*return*/, bytes.map(function (b) { return decodeString(b); })]; + } + catch (_b) { + throw new Error('Failed to decode the string bytes into utf-8. ' + + 'To get the original bytes, call tensor.bytes().'); + } + _a.label = 2; + case 2: return [2 /*return*/, data]; + } + }); + }); + }; + /** + * Synchronously downloads the values from the `tf.Tensor`. This blocks the + * UI thread until the values are ready, which can cause performance issues. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.dataSync = function () { + this.throwIfDisposed(); + var data = trackerFn().readSync(this.dataId); + if (this.dtype === 'string') { + try { + return data.map(function (b) { return decodeString(b); }); + } + catch (_a) { + throw new Error('Failed to decode the string bytes into utf-8. ' + + 'To get the original bytes, call tensor.bytes().'); + } + } + return data; + }; + /** Returns the underlying bytes of the tensor's data. */ + Tensor.prototype.bytes = function () { + return __awaiter(this, void 0, void 0, function () { + var data; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + this.throwIfDisposed(); + return [4 /*yield*/, trackerFn().read(this.dataId)]; + case 1: + data = _a.sent(); + if (this.dtype === 'string') { + return [2 /*return*/, data]; + } + else { + return [2 /*return*/, new Uint8Array(data.buffer)]; + } + return [2 /*return*/]; + } + }); + }); + }; + /** + * Disposes `tf.Tensor` from memory. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.dispose = function () { + if (this.isDisposed) { + return; + } + trackerFn().disposeTensor(this); + this.isDisposedInternal = true; + }; + Object.defineProperty(Tensor.prototype, "isDisposed", { + get: function () { + return this.isDisposedInternal; + }, + enumerable: true, + configurable: true + }); + Tensor.prototype.throwIfDisposed = function () { + if (this.isDisposed) { + throw new Error("Tensor is disposed."); + } + }; + /** Casts the array to type `float32` */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toFloat = function () { + return this.asType('float32'); + }; + /** Casts the array to type `int32` */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toInt = function () { + return this.asType('int32'); + }; + /** Casts the array to type `bool` */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toBool = function () { + return this.asType('bool'); + }; + /** + * Prints the `tf.Tensor`. See `tf.print` for details. + * + * @param verbose Whether to print verbose information about the tensor, + * including dtype and size. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.print = function (verbose) { + if (verbose === void 0) { verbose = false; } + return opHandler.print(this, verbose); + }; + /** + * Reshapes the tensor into the provided shape. + * See `tf.reshape` for more details. + * + * @param newShape An array of integers defining the output tensor shape. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.reshape = function (newShape) { + this.throwIfDisposed(); + return opHandler.reshape(this, newShape); + }; + /** + * Reshapes the tensor into the shape of the provided tensor. + * + * @param x The tensor of required shape. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.reshapeAs = function (x) { + this.throwIfDisposed(); + return this.reshape(x.shape); + }; + /** + * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension + * into the tensor's shape. See `tf.expandDims` for details. + * + * @param axis The dimension index at which to insert shape of 1. Defaults to + * 0 (the first dimension). + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.expandDims = function (axis) { + if (axis === void 0) { axis = 0; } + return opHandler.expandDims(this, axis); + }; + /** + * Returns the cumulative sum of the `tf.Tensor` along `axis`. + * + * @param axis The axis along which to sum. Optional. Defaults to 0. + * @param exclusive Whether to perform exclusive cumulative sum. Defaults to + * false. If set to true then the sum of each tensor entry does not + * include its own value, but only the values previous to it along the + * specified axis. + * @param reverse Whether to sum in the opposite direction. Defaults to + * false. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.cumsum = function (axis, exclusive, reverse) { + if (axis === void 0) { axis = 0; } + if (exclusive === void 0) { exclusive = false; } + if (reverse === void 0) { reverse = false; } + return opHandler.cumsum(this, axis, exclusive, reverse); + }; + /** + * Returns a `tf.Tensor` with dimensions of size 1 removed from the shape. + * See `tf.squeeze` for more details. + * + * @param axis A list of numbers. If specified, only squeezes the + * dimensions listed. The dimension index starts at 0. It is an error to + * squeeze a dimension that is not 1. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.squeeze = function (axis) { + this.throwIfDisposed(); + return opHandler.squeeze(this, axis); + }; + /** Returns a copy of the tensor. See `tf.clone` for details. */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.clone = function () { + this.throwIfDisposed(); + return opHandler.clone(this); + }; + Tensor.prototype.oneHot = function (depth, onValue, offValue) { + this.throwIfDisposed(); + return opHandler.oneHot(this, depth, onValue, offValue); + }; + /** + * Returns a human-readable description of the tensor. Useful for logging. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Tensor.prototype.toString = function (verbose) { + if (verbose === void 0) { verbose = false; } + var vals = this.dataSync(); + return tensorToString(vals, this.shape, this.dtype, verbose); + }; + // Below is chain API that is not exposed to docs to avoid repetition. To + // expose a method, move it above this comment and add @doc and jsdoc. + Tensor.prototype.tile = function (reps) { + this.throwIfDisposed(); + return opHandler.tile(this, reps); + }; + Tensor.prototype.gather = function (indices, axis) { + if (axis === void 0) { axis = 0; } + this.throwIfDisposed(); + return opHandler.gather(this, indices, axis); + }; + Tensor.prototype.matMul = function (b, transposeA, transposeB) { + if (transposeA === void 0) { transposeA = false; } + if (transposeB === void 0) { transposeB = false; } + this.throwIfDisposed(); + return opHandler.matMul(this, b, transposeA, transposeB); + }; + Tensor.prototype.dot = function (b) { + this.throwIfDisposed(); + return opHandler.dot(this, b); + }; + Tensor.prototype.norm = function (ord, axis, keepDims) { + if (ord === void 0) { ord = 'euclidean'; } + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.norm(this, ord, axis, keepDims); + }; + Tensor.prototype.slice = function (begin, size) { + this.throwIfDisposed(); + return opHandler.slice(this, begin, size); + }; + Tensor.prototype.reverse = function (axis) { + this.throwIfDisposed(); + return opHandler.reverse(this, axis); + }; + Tensor.prototype.concat = function (x, axis) { + if (axis === void 0) { axis = 0; } + this.throwIfDisposed(); + if (x instanceof Tensor) { + x = [x]; + } + return opHandler.concat([this].concat(x), axis); + }; + Tensor.prototype.split = function (numOrSizeSplits, axis) { + if (axis === void 0) { axis = 0; } + this.throwIfDisposed(); + return opHandler.split(this, numOrSizeSplits, axis); + }; + Tensor.prototype.stack = function (x, axis) { + if (axis === void 0) { axis = 0; } + return opHandler.stack([this, x], axis); + }; + Tensor.prototype.unstack = function (axis) { + if (axis === void 0) { axis = 0; } + return opHandler.unstack(this, axis); + }; + Tensor.prototype.pad = function (paddings, constantValue) { + if (constantValue === void 0) { constantValue = 0; } + return opHandler.pad(this, paddings, constantValue); + }; + /** + * @deprecated Use `tf.batchNorm` instead, and note the positional argument + * change of scale, offset, and varianceEpsilon. + */ + Tensor.prototype.batchNormalization = function (mean, variance, varianceEpsilon, scale, offset) { + if (varianceEpsilon === void 0) { varianceEpsilon = .001; } + deprecationWarningFn('tf.batchNormalization() is going away. ' + + 'Use tf.batchNorm() instead, and note the positional argument change ' + + 'of scale, offset, and varianceEpsilon'); + return this.batchNorm(mean, variance, offset, scale, varianceEpsilon); + }; + Tensor.prototype.batchNorm = function (mean, variance, offset, scale, varianceEpsilon) { + if (varianceEpsilon === void 0) { varianceEpsilon = .001; } + this.throwIfDisposed(); + return opHandler.batchNorm(this, mean, variance, offset, scale, varianceEpsilon); + }; + // Reduction ops. + Tensor.prototype.all = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.all(this, axis, keepDims); + }; + Tensor.prototype.any = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.any(this, axis, keepDims); + }; + Tensor.prototype.logSumExp = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.logSumExp(this, axis, keepDims); + }; + Tensor.prototype.sum = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.sum(this, axis, keepDims); + }; + Tensor.prototype.prod = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.prod(this, axis, keepDims); + }; + Tensor.prototype.mean = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.mean(this, axis, keepDims); + }; + Tensor.prototype.min = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.min(this, axis, keepDims); + }; + Tensor.prototype.max = function (axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + this.throwIfDisposed(); + return opHandler.max(this, axis, keepDims); + }; + Tensor.prototype.argMin = function (axis) { + if (axis === void 0) { axis = null; } + this.throwIfDisposed(); + return opHandler.argMin(this, axis); + }; + Tensor.prototype.argMax = function (axis) { + if (axis === void 0) { axis = null; } + this.throwIfDisposed(); + return opHandler.argMax(this, axis); + }; + // Transformations + Tensor.prototype.cast = function (dtype) { + this.throwIfDisposed(); + return opHandler.cast(this, dtype); + }; + // Binary ops. + Tensor.prototype.add = function (x) { + this.throwIfDisposed(); + return opHandler.add(this, x); + }; + Tensor.prototype.addStrict = function (x) { + this.throwIfDisposed(); + return opHandler.addStrict(this, x); + }; + Tensor.prototype.atan2 = function (x) { + this.throwIfDisposed(); + return opHandler.atan2(this, x); + }; + Tensor.prototype.sub = function (x) { + this.throwIfDisposed(); + return opHandler.sub(this, x); + }; + Tensor.prototype.subStrict = function (x) { + this.throwIfDisposed(); + return opHandler.subStrict(this, x); + }; + Tensor.prototype.pow = function (exp) { + this.throwIfDisposed(); + return opHandler.pow(this, exp); + }; + Tensor.prototype.powStrict = function (exp) { + this.throwIfDisposed(); + return opHandler.powStrict(this, exp); + }; + Tensor.prototype.mul = function (x) { + this.throwIfDisposed(); + return opHandler.mul(this, x); + }; + Tensor.prototype.mulStrict = function (x) { + this.throwIfDisposed(); + return opHandler.mulStrict(this, x); + }; + Tensor.prototype.div = function (x) { + this.throwIfDisposed(); + return opHandler.div(this, x); + }; + Tensor.prototype.divNoNan = function (x) { + this.throwIfDisposed(); + return opHandler.divNoNan(this, x); + }; + Tensor.prototype.floorDiv = function (x) { + this.throwIfDisposed(); + return opHandler.floorDiv(this, x); + }; + Tensor.prototype.divStrict = function (x) { + this.throwIfDisposed(); + return opHandler.divStrict(this, x); + }; + Tensor.prototype.minimum = function (x) { + this.throwIfDisposed(); + return opHandler.minimum(this, x); + }; + Tensor.prototype.minimumStrict = function (x) { + this.throwIfDisposed(); + return opHandler.minimumStrict(this, x); + }; + Tensor.prototype.maximum = function (x) { + this.throwIfDisposed(); + return opHandler.maximum(this, x); + }; + Tensor.prototype.maximumStrict = function (x) { + this.throwIfDisposed(); + return opHandler.maximumStrict(this, x); + }; + Tensor.prototype.mod = function (x) { + this.throwIfDisposed(); + return opHandler.mod(this, x); + }; + Tensor.prototype.modStrict = function (x) { + this.throwIfDisposed(); + return opHandler.modStrict(this, x); + }; + Tensor.prototype.squaredDifference = function (x) { + this.throwIfDisposed(); + return opHandler.squaredDifference(this, x); + }; + Tensor.prototype.squaredDifferenceStrict = function (x) { + this.throwIfDisposed(); + return opHandler.squaredDifferenceStrict(this, x); + }; + Tensor.prototype.transpose = function (perm) { + this.throwIfDisposed(); + return opHandler.transpose(this, perm); + }; + // Compare ops. + Tensor.prototype.notEqual = function (x) { + this.throwIfDisposed(); + return opHandler.notEqual(this, x); + }; + Tensor.prototype.notEqualStrict = function (x) { + this.throwIfDisposed(); + return opHandler.notEqualStrict(this, x); + }; + Tensor.prototype.less = function (x) { + this.throwIfDisposed(); + return opHandler.less(this, x); + }; + Tensor.prototype.lessStrict = function (x) { + this.throwIfDisposed(); + return opHandler.lessStrict(this, x); + }; + Tensor.prototype.equal = function (x) { + this.throwIfDisposed(); + return opHandler.equal(this, x); + }; + Tensor.prototype.equalStrict = function (x) { + this.throwIfDisposed(); + return opHandler.equalStrict(this, x); + }; + Tensor.prototype.lessEqual = function (x) { + this.throwIfDisposed(); + return opHandler.lessEqual(this, x); + }; + Tensor.prototype.lessEqualStrict = function (x) { + this.throwIfDisposed(); + return opHandler.lessEqualStrict(this, x); + }; + Tensor.prototype.greater = function (x) { + this.throwIfDisposed(); + return opHandler.greater(this, x); + }; + Tensor.prototype.greaterStrict = function (x) { + this.throwIfDisposed(); + return opHandler.greaterStrict(this, x); + }; + Tensor.prototype.greaterEqual = function (x) { + this.throwIfDisposed(); + return opHandler.greaterEqual(this, x); + }; + Tensor.prototype.greaterEqualStrict = function (x) { + this.throwIfDisposed(); + return opHandler.greaterEqualStrict(this, x); + }; + // Compare ops. + Tensor.prototype.logicalAnd = function (x) { + this.throwIfDisposed(); + return opHandler.logicalAnd(this, x); + }; + Tensor.prototype.logicalOr = function (x) { + this.throwIfDisposed(); + return opHandler.logicalOr(this, x); + }; + Tensor.prototype.logicalNot = function () { + this.throwIfDisposed(); + return opHandler.logicalNot(this); + }; + Tensor.prototype.logicalXor = function (x) { + this.throwIfDisposed(); + return opHandler.logicalXor(this, x); + }; + Tensor.prototype.where = function (condition, x) { + this.throwIfDisposed(); + return opHandler.where(condition, this, x); + }; + // Unary ops. + Tensor.prototype.neg = function () { + this.throwIfDisposed(); + return opHandler.neg(this); + }; + Tensor.prototype.ceil = function () { + this.throwIfDisposed(); + return opHandler.ceil(this); + }; + Tensor.prototype.floor = function () { + this.throwIfDisposed(); + return opHandler.floor(this); + }; + Tensor.prototype.sign = function () { + this.throwIfDisposed(); + return opHandler.sign(this); + }; + Tensor.prototype.isNaN = function () { + this.throwIfDisposed(); + return opHandler.isNaN(this); + }; + Tensor.prototype.isInf = function () { + this.throwIfDisposed(); + return opHandler.isInf(this); + }; + Tensor.prototype.isFinite = function () { + this.throwIfDisposed(); + return opHandler.isFinite(this); + }; + Tensor.prototype.exp = function () { + this.throwIfDisposed(); + return opHandler.exp(this); + }; + Tensor.prototype.expm1 = function () { + this.throwIfDisposed(); + return opHandler.expm1(this); + }; + Tensor.prototype.log = function () { + this.throwIfDisposed(); + return opHandler.log(this); + }; + Tensor.prototype.log1p = function () { + this.throwIfDisposed(); + return opHandler.log1p(this); + }; + Tensor.prototype.sqrt = function () { + this.throwIfDisposed(); + return opHandler.sqrt(this); + }; + Tensor.prototype.rsqrt = function () { + this.throwIfDisposed(); + return opHandler.rsqrt(this); + }; + Tensor.prototype.square = function () { + this.throwIfDisposed(); + return opHandler.square(this); + }; + Tensor.prototype.reciprocal = function () { + this.throwIfDisposed(); + return opHandler.reciprocal(this); + }; + Tensor.prototype.abs = function () { + this.throwIfDisposed(); + return opHandler.abs(this); + }; + Tensor.prototype.clipByValue = function (min, max) { + this.throwIfDisposed(); + return opHandler.clipByValue(this, min, max); + }; + Tensor.prototype.relu = function () { + this.throwIfDisposed(); + return opHandler.relu(this); + }; + Tensor.prototype.relu6 = function () { + this.throwIfDisposed(); + return opHandler.relu6(this); + }; + Tensor.prototype.elu = function () { + this.throwIfDisposed(); + return opHandler.elu(this); + }; + Tensor.prototype.selu = function () { + this.throwIfDisposed(); + return opHandler.selu(this); + }; + Tensor.prototype.leakyRelu = function (alpha) { + if (alpha === void 0) { alpha = 0.2; } + this.throwIfDisposed(); + return opHandler.leakyRelu(this, alpha); + }; + Tensor.prototype.prelu = function (alpha) { + this.throwIfDisposed(); + return opHandler.prelu(this, alpha); + }; + Tensor.prototype.sigmoid = function () { + this.throwIfDisposed(); + return opHandler.sigmoid(this); + }; + Tensor.prototype.logSigmoid = function () { + this.throwIfDisposed(); + return opHandler.logSigmoid(this); + }; + Tensor.prototype.softplus = function () { + this.throwIfDisposed(); + return opHandler.softplus(this); + }; + Tensor.prototype.zerosLike = function () { + this.throwIfDisposed(); + return opHandler.zerosLike(this); + }; + Tensor.prototype.onesLike = function () { + this.throwIfDisposed(); + return opHandler.onesLike(this); + }; + Tensor.prototype.sin = function () { + this.throwIfDisposed(); + return opHandler.sin(this); + }; + Tensor.prototype.cos = function () { + this.throwIfDisposed(); + return opHandler.cos(this); + }; + Tensor.prototype.tan = function () { + this.throwIfDisposed(); + return opHandler.tan(this); + }; + Tensor.prototype.asin = function () { + this.throwIfDisposed(); + return opHandler.asin(this); + }; + Tensor.prototype.acos = function () { + this.throwIfDisposed(); + return opHandler.acos(this); + }; + Tensor.prototype.atan = function () { + this.throwIfDisposed(); + return opHandler.atan(this); + }; + Tensor.prototype.sinh = function () { + this.throwIfDisposed(); + return opHandler.sinh(this); + }; + Tensor.prototype.cosh = function () { + this.throwIfDisposed(); + return opHandler.cosh(this); + }; + Tensor.prototype.tanh = function () { + this.throwIfDisposed(); + return opHandler.tanh(this); + }; + Tensor.prototype.asinh = function () { + this.throwIfDisposed(); + return opHandler.asinh(this); + }; + Tensor.prototype.acosh = function () { + this.throwIfDisposed(); + return opHandler.acosh(this); + }; + Tensor.prototype.atanh = function () { + this.throwIfDisposed(); + return opHandler.atanh(this); + }; + Tensor.prototype.erf = function () { + this.throwIfDisposed(); + return opHandler.erf(this); + }; + Tensor.prototype.round = function () { + this.throwIfDisposed(); + return opHandler.round(this); + }; + Tensor.prototype.step = function (alpha) { + if (alpha === void 0) { alpha = 0.0; } + this.throwIfDisposed(); + return opHandler.step(this, alpha); + }; + Tensor.prototype.softmax = function (dim) { + if (dim === void 0) { dim = -1; } + this.throwIfDisposed(); + return opHandler.softmax(this, dim); + }; + Tensor.prototype.logSoftmax = function (axis) { + if (axis === void 0) { axis = -1; } + this.throwIfDisposed(); + return opHandler.logSoftmax(this, axis); + }; + // Image ops. + Tensor.prototype.resizeBilinear = function (newShape2D, alignCorners) { + if (alignCorners === void 0) { alignCorners = false; } + this.throwIfDisposed(); + return opHandler.image.resizeBilinear(this, newShape2D, alignCorners); + }; + Tensor.prototype.resizeNearestNeighbor = function (newShape2D, alignCorners) { + if (alignCorners === void 0) { alignCorners = false; } + this.throwIfDisposed(); + return opHandler.image.resizeNearestNeighbor(this, newShape2D, alignCorners); + }; + // Convolutions. + Tensor.prototype.conv1d = function (filter, stride, pad, dataFormat, dilation, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NWC'; } + if (dilation === void 0) { dilation = 1; } + this.throwIfDisposed(); + return opHandler.conv1d(this, filter, stride, pad, dataFormat, dilation, dimRoundingMode); + }; + Tensor.prototype.conv2d = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + if (dilations === void 0) { dilations = [1, 1]; } + this.throwIfDisposed(); + return opHandler.conv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode); + }; + Tensor.prototype.conv2dTranspose = function (filter, outputShape, strides, pad, dimRoundingMode) { + this.throwIfDisposed(); + return opHandler.conv2dTranspose(this, filter, outputShape, strides, pad, dimRoundingMode); + }; + Tensor.prototype.depthwiseConv2D = function (filter, strides, pad, dataFormat, dilations, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + if (dilations === void 0) { dilations = [1, 1]; } + this.throwIfDisposed(); + return opHandler.depthwiseConv2d(this, filter, strides, pad, dataFormat, dilations, dimRoundingMode); + }; + Tensor.prototype.separableConv2d = function (depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) { + if (dilation === void 0) { dilation = [1, 1]; } + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + this.throwIfDisposed(); + return opHandler.separableConv2d(this, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat); + }; + // Pooling. + Tensor.prototype.avgPool = function (filterSize, strides, pad, dimRoundingMode) { + this.throwIfDisposed(); + return opHandler.avgPool(this, filterSize, strides, pad, dimRoundingMode); + }; + Tensor.prototype.maxPool = function (filterSize, strides, pad, dimRoundingMode) { + this.throwIfDisposed(); + return opHandler.maxPool(this, filterSize, strides, pad, dimRoundingMode); + }; + Tensor.prototype.localResponseNormalization = function (radius, bias, alpha, beta) { + if (radius === void 0) { radius = 5; } + if (bias === void 0) { bias = 1; } + if (alpha === void 0) { alpha = 1; } + if (beta === void 0) { beta = 0.5; } + return opHandler.localResponseNormalization(this, radius, bias, alpha, beta); + }; + Tensor.prototype.pool = function (windowShape, poolingType, padding, dilationRate, strides) { + this.throwIfDisposed(); + return opHandler.pool(this, windowShape, poolingType, padding, dilationRate, strides); + }; + Tensor.prototype.variable = function (trainable, name, dtype) { + if (trainable === void 0) { trainable = true; } + this.throwIfDisposed(); + return trackerFn().makeVariable(this, trainable, name, dtype); + }; + Tensor.prototype.unsortedSegmentSum = function (segmentIds, numSegments) { + this.throwIfDisposed(); + return opHandler.unsortedSegmentSum(this, segmentIds, numSegments); + }; + Tensor.prototype.batchToSpaceND = function (blockShape, crops) { + this.throwIfDisposed(); + return opHandler.batchToSpaceND(this, blockShape, crops); + }; + Tensor.prototype.spaceToBatchND = function (blockShape, paddings) { + this.throwIfDisposed(); + return opHandler.spaceToBatchND(this, blockShape, paddings); + }; + Tensor.prototype.topk = function (k, sorted) { + if (k === void 0) { k = 1; } + if (sorted === void 0) { sorted = true; } + this.throwIfDisposed(); + return opHandler.topk(this, k, sorted); + }; + Tensor.prototype.stridedSlice = function (begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) { + if (beginMask === void 0) { beginMask = 0; } + if (endMask === void 0) { endMask = 0; } + if (ellipsisMask === void 0) { ellipsisMask = 0; } + if (newAxisMask === void 0) { newAxisMask = 0; } + if (shrinkAxisMask === void 0) { shrinkAxisMask = 0; } + this.throwIfDisposed(); + return opHandler.stridedSlice(this, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask); + }; + Tensor.prototype.depthToSpace = function (blockSize, dataFormat) { + this.throwIfDisposed(); + return opHandler.depthToSpace(this, blockSize, dataFormat); + }; + Tensor.prototype.fft = function () { + this.throwIfDisposed(); + return opHandler.spectral.fft(this); + }; + Tensor.prototype.ifft = function () { + this.throwIfDisposed(); + return opHandler.spectral.ifft(this); + }; + Tensor.prototype.rfft = function () { + this.throwIfDisposed(); + return opHandler.spectral.rfft(this); + }; + Tensor.prototype.irfft = function () { + this.throwIfDisposed(); + return opHandler.spectral.irfft(this); + }; + return Tensor; + }()); + Object.defineProperty(Tensor, Symbol.hasInstance, { + value: function (instance) { + return !!instance && instance.dataId != null && instance.shape != null && + instance.dtype != null; + } + }); + /** + * A mutable `tf.Tensor`, useful for persisting state, e.g. for training. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + var Variable = /** @class */ (function (_super) { + __extends(Variable, _super); + function Variable(initialValue, trainable, name, tensorId) { + var _this = _super.call(this, initialValue.shape, initialValue.dtype, initialValue.dataId, tensorId) || this; + _this.trainable = trainable; + _this.name = name; + return _this; + } + /** + * Assign a new `tf.Tensor` to this variable. The new `tf.Tensor` must have + * the same shape and dtype as the old `tf.Tensor`. + * + * @param newValue New tensor to be assigned to this variable. + */ + /** @doc {heading: 'Tensors', subheading: 'Classes'} */ + Variable.prototype.assign = function (newValue) { + if (newValue.dtype !== this.dtype) { + throw new Error("dtype of the new value (" + newValue.dtype + ") and " + + ("previous value (" + this.dtype + ") must match")); + } + if (!arraysEqual(newValue.shape, this.shape)) { + throw new Error("shape of the new value (" + newValue.shape + ") and " + + ("previous value (" + this.shape + ") must match")); + } + trackerFn().disposeTensor(this); + this.dataId = newValue.dataId; + trackerFn().incRef(this, null /* backend */); + }; + Variable.prototype.dispose = function () { + trackerFn().disposeVariable(this); + this.isDisposedInternal = true; + }; + return Variable; + }(Tensor)); + Object.defineProperty(Variable, Symbol.hasInstance, { + value: function (instance) { + return instance instanceof Tensor && instance.assign != null && + instance.assign instanceof Function; + } + }); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + (function (Rank) { + Rank["R0"] = "R0"; + Rank["R1"] = "R1"; + Rank["R2"] = "R2"; + Rank["R3"] = "R3"; + Rank["R4"] = "R4"; + Rank["R5"] = "R5"; + Rank["R6"] = "R6"; + })(exports.Rank || (exports.Rank = {})); + // Looks for upcasting types. Used, for example, in operations with mixed dtype + // inputs. + var UpcastInt32AndMap; + (function (UpcastInt32AndMap) { + UpcastInt32AndMap["float32"] = "float32"; + UpcastInt32AndMap["int32"] = "int32"; + UpcastInt32AndMap["bool"] = "int32"; + UpcastInt32AndMap["complex64"] = "complex64"; + })(UpcastInt32AndMap || (UpcastInt32AndMap = {})); + var UpcastBoolAndMap; + (function (UpcastBoolAndMap) { + UpcastBoolAndMap["float32"] = "float32"; + UpcastBoolAndMap["int32"] = "int32"; + UpcastBoolAndMap["bool"] = "bool"; + UpcastBoolAndMap["complex64"] = "complex64"; + })(UpcastBoolAndMap || (UpcastBoolAndMap = {})); + var UpcastFloat32AndMap; + (function (UpcastFloat32AndMap) { + UpcastFloat32AndMap["float32"] = "float32"; + UpcastFloat32AndMap["int32"] = "float32"; + UpcastFloat32AndMap["bool"] = "float32"; + UpcastFloat32AndMap["complex64"] = "complex64"; + })(UpcastFloat32AndMap || (UpcastFloat32AndMap = {})); + var UpcastComplex64AndMap; + (function (UpcastComplex64AndMap) { + UpcastComplex64AndMap["float32"] = "complex64"; + UpcastComplex64AndMap["int32"] = "complex64"; + UpcastComplex64AndMap["bool"] = "complex64"; + UpcastComplex64AndMap["complex64"] = "complex64"; + })(UpcastComplex64AndMap || (UpcastComplex64AndMap = {})); + var upcastTypeMap = { + 'float32': UpcastFloat32AndMap, + 'int32': UpcastInt32AndMap, + 'bool': UpcastBoolAndMap, + 'complex64': UpcastComplex64AndMap + }; + function upcastType(typeA, typeB) { + if (typeA === 'string' || typeB === 'string') { + if (typeA === 'string' && typeB === 'string') { + return 'string'; + } + throw new Error("Can not upcast " + typeA + " with " + typeB); + } + return upcastTypeMap[typeA][typeB]; + } + /** Returns the output type after summation. */ + function sumOutType(type) { + return upcastType(type, 'int32'); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function makeTypesMatch(a, b) { + if (a.dtype === b.dtype) { + return [a, b]; + } + var dtype = upcastType(a.dtype, b.dtype); + return [a.cast(dtype), b.cast(dtype)]; + } + function assertTypesMatch(a, b) { + assert(a.dtype === b.dtype, function () { return "The dtypes of the first(" + a.dtype + ") and" + + (" second(" + b.dtype + ") input must match"); }); + } + function isTensorInList(tensor, tensorList) { + for (var i = 0; i < tensorList.length; i++) { + if (tensorList[i].id === tensor.id) { + return true; + } + } + return false; + } + /** + * Extracts any `Tensor`s found within the provided object. + * + * @param container an object that may be a `Tensor` or may directly contain + * `Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. In general it + * is safe to pass any object here, except that `Promise`s are not + * supported. + * @returns An array of `Tensors` found within the passed object. If the + * argument is simply a `Tensor', a list containing that `Tensor` is + * returned. If the object is not a `Tensor` or does not + * contain `Tensors`, an empty list is returned. + */ + function getTensorsInContainer(result) { + var list = []; + var seen = new Set(); + walkTensorContainer(result, list, seen); + return list; + } + function walkTensorContainer(container, list, seen) { + if (container == null) { + return; + } + if (container instanceof Tensor) { + list.push(container); + return; + } + if (!isIterable(container)) { + return; + } + // Iteration over keys works also for arrays. + var iterable = container; + for (var k in iterable) { + var val = iterable[k]; + if (!seen.has(val)) { + seen.add(val); + walkTensorContainer(val, list, seen); + } + } + } + // tslint:disable-next-line:no-any + function isIterable(obj) { + return Array.isArray(obj) || typeof obj === 'object'; + } + + var tensor_util = /*#__PURE__*/Object.freeze({ + makeTypesMatch: makeTypesMatch, + assertTypesMatch: assertTypesMatch, + isTensorInList: isTensorInList, + getTensorsInContainer: getTensorsInContainer + }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var EngineState = /** @class */ (function () { + function EngineState() { + // Public since optimizers will use it. + this.registeredVariables = {}; + this.nextTapeNodeId = 0; + this.numBytes = 0; + this.numTensors = 0; + this.numStringTensors = 0; + this.numDataBuffers = 0; + // Number of nested tf.grad() statements when computing higher-order + // gradients. E.g. `1` for first-order gradients and `2` for second-order + // gradients. Used to track if the tape should be removed after a backprop. + this.gradientDepth = 0; + // Number of nested kernel calls. When kernel depth is greater than 1, we turn + // off the tape. + this.kernelDepth = 0; + this.scopeStack = []; + /** + * Keeps track of the number of data moves during a kernel execution. We + * maintain a stack since kernels can call other kernels, recursively. + */ + this.numDataMovesStack = []; + this.nextScopeId = 0; + this.tensorInfo = new WeakMap(); + this.profiling = false; + this.activeProfile = { newBytes: 0, newTensors: 0, peakBytes: 0, kernels: [], result: null }; + } + EngineState.prototype.dispose = function () { + for (var variableName in this.registeredVariables) { + this.registeredVariables[variableName].dispose(); + } + }; + return EngineState; + }()); + var Engine = /** @class */ (function () { + function Engine(ENV) { + this.ENV = ENV; + this.registry = {}; + this.registryFactory = {}; + this.pendingBackendInitId = 0; + this.state = new EngineState(); + } + Engine.prototype.ready = function () { + return __awaiter(this, void 0, void 0, function () { + var sortedBackends, i, backendName, success; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (this.pendingBackendInit != null) { + return [2 /*return*/, this.pendingBackendInit.then(function () { })]; + } + if (this.backendInstance != null) { + return [2 /*return*/]; + } + sortedBackends = this.getSortedBackends(); + i = 0; + _a.label = 1; + case 1: + if (!(i < sortedBackends.length)) return [3 /*break*/, 5]; + backendName = sortedBackends[i]; + return [4 /*yield*/, this.initializeBackend(backendName).success]; + case 2: + success = _a.sent(); + if (!success) return [3 /*break*/, 4]; + return [4 /*yield*/, this.setBackend(backendName)]; + case 3: + _a.sent(); + return [2 /*return*/]; + case 4: + i++; + return [3 /*break*/, 1]; + case 5: throw new Error("Could not initialize any backends, all backend initializations " + + "failed."); + } + }); + }); + }; + Object.defineProperty(Engine.prototype, "backend", { + get: function () { + if (this.pendingBackendInit != null) { + throw new Error("Backend '" + this.backendName + "' has not yet been initialized. Make " + + "sure to await tf.ready() before calling other methods"); + } + if (this.backendInstance == null) { + var _a = this.initializeBackendsAndReturnBest(), name_1 = _a.name, asyncInit = _a.asyncInit; + if (asyncInit) { + throw new Error("The highest priority backend '" + name_1 + "' has not yet been " + + "initialized. Make sure to await tf.ready() before calling " + + "other methods"); + } + this.setBackend(name_1); + } + return this.backendInstance; + }, + enumerable: true, + configurable: true + }); + Engine.prototype.backendNames = function () { + return Object.keys(this.registryFactory); + }; + Engine.prototype.findBackend = function (backendName) { + if (!(backendName in this.registry)) { + // If the backend hasn't been initialized but we have a registry entry for + // it, initialize it and return it. + if (backendName in this.registryFactory) { + var asyncInit = this.initializeBackend(backendName).asyncInit; + if (asyncInit) { + // Backend is not ready yet. + return null; + } + } + else { + return null; + } + } + return this.registry[backendName]; + }; + Engine.prototype.findBackendFactory = function (backendName) { + if (!(backendName in this.registryFactory)) { + return null; + } + return this.registryFactory[backendName].factory; + }; + Engine.prototype.registerBackend = function (backendName, factory, priority) { + if (priority === void 0) { priority = 1; } + if (backendName in this.registryFactory) { + console.warn(backendName + " backend was already registered. " + + "Reusing existing backend factory."); + return false; + } + this.registryFactory[backendName] = { factory: factory, priority: priority }; + return true; + }; + Engine.prototype.setBackend = function (backendName) { + return __awaiter(this, void 0, void 0, function () { + var _a, success, asyncInit, result, _b; + return __generator(this, function (_c) { + switch (_c.label) { + case 0: + if (this.registryFactory[backendName] == null) { + throw new Error("Backend name '" + backendName + "' not found in registry"); + } + this.backendName = backendName; + if (!(this.registry[backendName] == null)) return [3 /*break*/, 4]; + this.backendInstance = null; + _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; + if (!asyncInit) return [3 /*break*/, 2]; + return [4 /*yield*/, success]; + case 1: + _b = _c.sent(); + return [3 /*break*/, 3]; + case 2: + _b = success; + _c.label = 3; + case 3: + result = _b; + if (!result) { + return [2 /*return*/, false]; + } + _c.label = 4; + case 4: + this.backendInstance = this.registry[backendName]; + this.setupRegisteredKernels(); + // Reset the profiler. + this.profiler = new Profiler(this.backendInstance); + return [2 /*return*/, true]; + } + }); + }); + }; + Engine.prototype.setupRegisteredKernels = function () { + var _this = this; + var kernels = getKernelsForBackend(this.backendName); + kernels.forEach(function (kernel) { + if (kernel.setupFunc != null) { + kernel.setupFunc(_this.backendInstance); + } + }); + }; + Engine.prototype.disposeRegisteredKernels = function (backendName) { + var _this = this; + var kernels = getKernelsForBackend(backendName); + kernels.forEach(function (kernel) { + if (kernel.disposeFunc != null) { + kernel.disposeFunc(_this.registry[backendName]); + } + }); + }; + /** + * Initializes a backend by looking up the backend name in the factory + * registry and calling the factory method. Returns a boolean representing + * whether the initialization of the backend suceeded. Throws an error if + * there is no backend in the factory registry. + */ + Engine.prototype.initializeBackend = function (backendName) { + var _this = this; + var registryFactoryEntry = this.registryFactory[backendName]; + if (registryFactoryEntry == null) { + throw new Error("Cannot initialize backend " + backendName + ", no registration found."); + } + try { + var backend = registryFactoryEntry.factory(); + // Test if the factory returns a promise. + if (Promise.resolve(backend) === backend) { + var promiseId_1 = ++this.pendingBackendInitId; + var success = backend + .then(function (backendInstance) { + // Outdated promise. Another backend was set in the meantime. + if (promiseId_1 < _this.pendingBackendInitId) { + return false; + } + _this.registry[backendName] = backendInstance; + _this.pendingBackendInit = null; + return true; + }) + .catch(function (err) { + // Outdated promise. Another backend was set in the meantime. + if (promiseId_1 < _this.pendingBackendInitId) { + return false; + } + _this.pendingBackendInit = null; + console.warn("Initialization of backend " + backendName + " failed"); + console.warn(err.stack || err.message); + return false; + }); + this.pendingBackendInit = success; + return { success: success, asyncInit: true }; + } + else { + this.registry[backendName] = backend; + return { success: true, asyncInit: false }; + } + } + catch (err) { + console.warn("Initialization of backend " + backendName + " failed"); + console.warn(err.stack || err.message); + return { success: false, asyncInit: false }; + } + }; + Engine.prototype.removeBackend = function (backendName) { + if (!(backendName in this.registryFactory)) { + throw new Error(backendName + " backend not found in registry"); + } + if (this.backendName === backendName && this.pendingBackendInit != null) { + // There is a pending promise of the backend we want to remove. Make it + // obsolete. + this.pendingBackendInitId++; + } + if (backendName in this.registry) { + this.disposeRegisteredKernels(backendName); + this.registry[backendName].dispose(); + delete this.registry[backendName]; + } + delete this.registryFactory[backendName]; + // Unset the backend if it is active. + if (this.backendName === backendName) { + this.pendingBackendInit = null; + this.backendName = null; + this.backendInstance = null; + } + }; + Engine.prototype.getSortedBackends = function () { + var _this = this; + if (Object.keys(this.registryFactory).length === 0) { + throw new Error('No backend found in registry.'); + } + return Object.keys(this.registryFactory).sort(function (a, b) { + // Highest priority comes first. + return _this.registryFactory[b].priority - + _this.registryFactory[a].priority; + }); + }; + Engine.prototype.initializeBackendsAndReturnBest = function () { + var sortedBackends = this.getSortedBackends(); + for (var i = 0; i < sortedBackends.length; i++) { + var backendName = sortedBackends[i]; + var _a = this.initializeBackend(backendName), success = _a.success, asyncInit = _a.asyncInit; + if (asyncInit || success) { + return { name: backendName, asyncInit: asyncInit }; + } + } + throw new Error("Could not initialize any backends, all backend initializations " + + "failed."); + }; + Engine.prototype.moveData = function (destBackend, dataId) { + var info = this.state.tensorInfo.get(dataId); + var srcBackend = info.backend; + var values = this.readSync(dataId); + // Delete the tensor from the old backend and move it to the new + // backend. + srcBackend.disposeData(dataId); + info.backend = destBackend; + destBackend.move(dataId, values, info.shape, info.dtype); + if (this.shouldCheckForMemLeaks()) { + // Track the number of moves during a kernel execution to correctly + // detect memory leaks. + this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]++; + } + }; + Engine.prototype.tidy = function (nameOrFn, fn) { + var _this = this; + var name = null; + if (fn == null) { + // Called with only 1 argument. + if (typeof nameOrFn !== 'function') { + throw new Error('Please provide a function to tidy()'); + } + fn = nameOrFn; + } + else { + // Called with 2 arguments. + if (typeof nameOrFn !== 'string' && !(nameOrFn instanceof String)) { + throw new Error('When calling with two arguments, the first argument ' + + 'to tidy() must be a string'); + } + if (typeof fn !== 'function') { + throw new Error('When calling with two arguments, the 2nd argument ' + + 'to tidy() must be a function'); + } + name = nameOrFn; + // TODO(nsthorat,smilkov): Do operation logging and performance + // profiling. + } + var result; + return this.scopedRun(function () { return _this.startScope(name); }, function () { return _this.endScope(result); }, function () { + result = fn(); + if (result instanceof Promise) { + console.error('Cannot return a Promise inside of tidy.'); + } + return result; + }); + }; + Engine.prototype.scopedRun = function (start, end, f) { + start(); + try { + var res = f(); + end(); + return res; + } + catch (ex) { + end(); + throw ex; + } + }; + Engine.prototype.nextTensorId = function () { + return Engine.nextTensorId++; + }; + Engine.prototype.nextVariableId = function () { + return Engine.nextVariableId++; + }; + /** + * This method is called instead of the public-facing tensor.clone() when + * saving a tensor for backwards pass. It makes sure to add the clone + * operation to the tape regardless of being called inside a kernel + * execution. + * + * This method will go away once all kernels are modularized since we won't + * need to turn off the tape inside runKernel(). + */ + Engine.prototype.clone = function (x) { + var y = this.makeTensorFromDataId(x.dataId, x.shape, x.dtype); + var inputs = { x: x }; + var grad = function (dy) { return ({ x: function () { return dy.toFloat(); } }); }; + var saved = []; + this.addTapeNode(this.state.activeScope.name, inputs, [y], grad, saved); + return y; + }; + /** + * Execute a kernel with the given name and return the output tensor. + * + * @param kernelName The name of the kernel to execute. + * @param inputs A map of input names to tensors. + * @param attrs A map of attribute names to their values. An attribute is a + * primitive (non-tensor) input to the kernel. + * @param inputsToSave A list of tensors, inputs to save for the backprop + * computation. + * @param outputsToSave A list of booleans, specifying which output to save + * for the backprop computation. These are booleans since the output + * tensors are not visible to the user. + */ + Engine.prototype.runKernel = function (kernelName, inputs, attrs, inputsToSave, outputsToSave) { + var forwardFunc = null; + var backwardsFunc = null; + // Call runKernel as a stop-gap until we modularize all kernels. + // Once we modularize all kernels, we will remove the existing + // `runKernelFunc`. + return this.runKernelFunc(forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave); + }; + Engine.prototype.shouldCheckForMemLeaks = function () { + return this.ENV.getBool('IS_TEST'); + }; + Engine.prototype.checkKernelForMemLeak = function (scopeName, numDataIdsBefore, outInfos) { + var numDataIdsAfter = this.backend.numDataIds(); + // Count the number of data ids associated with the result of the kernel. + var numOutputDataIds = 0; + outInfos.forEach(function (info) { + // Complex numbers allocate 3 data ids, one for 'real', one for + // 'imaginary', and one for the container that holds the former two. + numOutputDataIds += (info.dtype === 'complex64' ? 3 : 1); + }); + // Account for the number of moves during kernel execution. A "data move" + // can happen in the middle of a kernel execution, placing a new (key,value) + // pair in the data storage. Since data moves have net zero effect (we + // always remove the data from the old backend), we have to cancel them out + // when detecting memory leaks. + var numMoves = this.state.numDataMovesStack[this.state.numDataMovesStack.length - 1]; + var dataIdsLeaked = numDataIdsAfter - numDataIdsBefore - numOutputDataIds - numMoves; + if (dataIdsLeaked > 0) { + throw new Error("Backend '" + this.backendName + "' has an internal memory leak " + + ("(" + dataIdsLeaked + " data ids) after running '" + scopeName + "'")); + } + }; + /** + * @deprecated Use `runKernel` for newly added kernels. Keep using this method + * only for kernels that are not yet fully modularized. + */ + Engine.prototype.runKernelFunc = function (forwardFunc, inputs, backwardsFunc, kernelName, attrs, inputsToSave, outputsToSave) { + var _this = this; + if (inputsToSave === void 0) { inputsToSave = []; } + if (outputsToSave === void 0) { outputsToSave = []; } + var outputs; + var saved = []; + var isTapeOn = this.isTapeOn(); + var scopeName = this.state.activeScope != null ? this.state.activeScope.name : ''; + var saveFunc = function (tensors) { + // Do not save unless we are recording to the tape. Otherwise it would + // cause a mem leak since we would never run backprop, which disposes + // the kept tensors. + if (!isTapeOn) { + return; + } + saved = tensors.map(function (tensor) { return _this.keep(_this.clone(tensor)); }); + }; + var startingBytecount = this.state.numBytes; + var startingNumTensors = this.state.numTensors; + if (this.shouldCheckForMemLeaks()) { + this.state.numDataMovesStack.push(0); + } + var kernelFunc; + var kernel = getKernel(kernelName, this.backendName); + var out; + if (kernel != null) { + kernelFunc = function () { + var numDataIdsBefore = _this.backend.numDataIds(); + out = kernel.kernelFunc({ inputs: inputs, attrs: attrs, backend: _this.backend }); + var outInfos = Array.isArray(out) ? out : [out]; + if (_this.shouldCheckForMemLeaks()) { + _this.checkKernelForMemLeak(scopeName, numDataIdsBefore, outInfos); + } + var outTensors = outInfos.map(function (_a) { + var dataId = _a.dataId, shape = _a.shape, dtype = _a.dtype; + return _this.makeTensorFromDataId(dataId, shape, dtype); + }); + var outsToSave = outTensors.filter(function (_, i) { return outputsToSave[i]; }); + // Save the inputs and outputs. + saveFunc(inputsToSave.slice().concat(outsToSave)); + return outTensors; + }; + } + else { + kernelFunc = function () { + var numDataIdsBefore = _this.backend.numDataIds(); + out = _this.tidy(function () { return forwardFunc(_this.backend, saveFunc); }); + var outs = (Array.isArray(out) ? out : [out]); + if (_this.shouldCheckForMemLeaks()) { + _this.checkKernelForMemLeak(scopeName, numDataIdsBefore, outs); + } + return outs; + }; + } + // Stop recording to a tape when running a kernel. + this.scopedRun(function () { return _this.state.kernelDepth++; }, function () { return _this.state.kernelDepth--; }, function () { + if (!_this.ENV.getBool('DEBUG')) { + outputs = kernelFunc(); + } + else { + outputs = _this.profiler.profileKernel(scopeName, inputs, function () { return kernelFunc(); }); + } + }); + if (isTapeOn) { + this.addTapeNode(scopeName, inputs, outputs, backwardsFunc, saved); + } + if (this.state.profiling) { + this.state.activeProfile.kernels.push({ + name: scopeName, + bytesAdded: this.state.numBytes - startingBytecount, + totalBytesSnapshot: this.state.numBytes, + tensorsAdded: this.state.numTensors - startingNumTensors, + totalTensorsSnapshot: this.state.numTensors, + inputShapes: Object.keys(inputs).map(function (key) { return inputs[key].shape; }), + outputShapes: outputs.map(function (item) { return item.shape; }) + }); + } + return (Array.isArray(out) ? outputs : outputs[0]); + }; + /** + * Internal method used by public APIs for tensor creation. Makes a new + * tensor with the provided shape, dtype and values. It always + * creates a new data id and writes the values to the underlying backend. + */ + Engine.prototype.makeTensor = function (values, shape, dtype, backend) { + if (values == null) { + throw new Error('Values passed to engine.makeTensor() are null'); + } + dtype = dtype || 'float32'; + backend = backend || this.backend; + var backendVals = values; + if (dtype === 'string' && isString(values[0])) { + backendVals = values.map(function (d) { return encodeString(d); }); + } + var dataId = backend.write(backendVals, shape, dtype); + var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); + this.incRef(t, backend); + // Count bytes for string tensors. + if (dtype === 'string') { + var info = this.state.tensorInfo.get(dataId); + var newBytes = bytesFromStringArray(backendVals); + this.state.numBytes += newBytes - info.bytes; + info.bytes = newBytes; + } + return t; + }; + /** + * Internal method used by backends. Makes a new tensor + * that is a wrapper around an existing data id. It doesn't create + * a new data id, only increments the ref count used in memory tracking. + */ + Engine.prototype.makeTensorFromDataId = function (dataId, shape, dtype, backend) { + dtype = dtype || 'float32'; + var t = new Tensor(shape, dtype, dataId, this.nextTensorId()); + this.incRef(t, backend); + return t; + }; + Engine.prototype.makeVariable = function (initialValue, trainable, name, dtype) { + if (trainable === void 0) { trainable = true; } + name = name || this.nextVariableId().toString(); + if (dtype != null && dtype !== initialValue.dtype) { + initialValue = initialValue.asType(dtype); + } + var v = new Variable(initialValue, trainable, name, this.nextTensorId()); + if (this.state.registeredVariables[v.name] != null) { + throw new Error("Variable with name " + v.name + " was already registered"); + } + this.state.registeredVariables[v.name] = v; + this.incRef(v, this.backend); + return v; + }; + Engine.prototype.incRef = function (a, backend) { + var refCount = this.state.tensorInfo.has(a.dataId) ? + this.state.tensorInfo.get(a.dataId).refCount : + 0; + this.state.numTensors++; + if (a.dtype === 'string') { + this.state.numStringTensors++; + } + if (refCount === 0) { + this.state.numDataBuffers++; + // Bytes for complex numbers are counted by their components. Bytes for + // string tensors are counted when writing values. + var bytes = 0; + if (a.dtype !== 'complex64' && a.dtype !== 'string') { + bytes = a.size * bytesPerElement(a.dtype); + } + this.state.tensorInfo.set(a.dataId, { + backend: backend || this.backend, + dtype: a.dtype, + shape: a.shape, + bytes: bytes, + refCount: 0 + }); + this.state.numBytes += bytes; + } + this.state.tensorInfo.get(a.dataId).refCount++; + if (!(a instanceof Variable)) { + this.track(a); + } + }; + Engine.prototype.disposeTensor = function (a) { + if (!this.state.tensorInfo.has(a.dataId)) { + return; + } + this.state.numTensors--; + if (a.dtype === 'string') { + this.state.numStringTensors--; + } + var info = this.state.tensorInfo.get(a.dataId); + var refCount = info.refCount; + if (refCount <= 1) { + // Don't count bytes for complex numbers as they are counted by their + // components. + if (a.dtype !== 'complex64') { + this.state.numBytes -= info.bytes; + } + this.state.numDataBuffers--; + info.backend.disposeData(a.dataId); + this.state.tensorInfo.delete(a.dataId); + } + else { + this.state.tensorInfo.get(a.dataId).refCount--; + } + // TODO(nsthorat): Construct an error and save the stack trace for + // debugging when in debug mode. Creating a stack trace is too expensive + // to do unconditionally. + }; + Engine.prototype.disposeVariables = function () { + for (var varName in this.state.registeredVariables) { + var v = this.state.registeredVariables[varName]; + this.disposeVariable(v); + } + }; + Engine.prototype.disposeVariable = function (v) { + this.disposeTensor(v); + if (this.state.registeredVariables[v.name] != null) { + delete this.state.registeredVariables[v.name]; + } + }; + Engine.prototype.memory = function () { + var info = this.backend.memory(); + info.numTensors = this.state.numTensors; + info.numDataBuffers = this.state.numDataBuffers; + info.numBytes = this.state.numBytes; + if (this.state.numStringTensors > 0) { + info.unreliable = true; + if (info.reasons == null) { + info.reasons = []; + } + info.reasons.push('Memory usage by string tensors is approximate ' + + '(2 bytes per character)'); + } + return info; + }; + Engine.prototype.profile = function (query) { + return __awaiter(this, void 0, void 0, function () { + var startBytes, startNumTensors; + return __generator(this, function (_a) { + this.state.profiling = true; + startBytes = this.state.numBytes; + startNumTensors = this.state.numTensors; + this.state.activeProfile.kernels = []; + this.state.activeProfile.result = query(); + this.state.profiling = false; + this.state.activeProfile.peakBytes = Math.max.apply(Math, this.state.activeProfile.kernels.map(function (d) { return d.totalBytesSnapshot; })); + this.state.activeProfile.newBytes = this.state.numBytes - startBytes; + this.state.activeProfile.newTensors = + this.state.numTensors - startNumTensors; + return [2 /*return*/, this.state.activeProfile]; + }); + }); + }; + Engine.prototype.isTapeOn = function () { + return this.state.gradientDepth > 0 && this.state.kernelDepth === 0; + }; + Engine.prototype.addTapeNode = function (scopeName, inputs, outputs, gradientsFunc, saved) { + var _this = this; + var tapeNode = { + id: this.state.nextTapeNodeId++, + name: scopeName, + inputs: inputs, + outputs: outputs, + saved: saved + }; + if (gradientsFunc != null) { + tapeNode.gradient = function (dys) { + // TODO(smilkov): To optimize back-prop, pass dys that are not used in + // the backprop graph to the user as null instead of zeros + dys = dys.map(function (dy, i) { + if (dy == null) { + var output = outputs[i]; + var vals = makeZerosTypedArray(output.size, output.dtype); + return _this.makeTensor(vals, output.shape, output.dtype); + } + return dy; + }); + // Grad functions of ops with single outputs expect a dy, while ops + // with multiple outputs expect dys (array of dy). + return gradientsFunc(dys.length > 1 ? dys : dys[0], saved); + }; + } + this.state.activeTape.push(tapeNode); + }; + Engine.prototype.keep = function (result) { + result.kept = true; + return result; + }; + Engine.prototype.startTape = function () { + if (this.state.gradientDepth === 0) { + this.state.activeTape = []; + } + this.state.gradientDepth++; + }; + Engine.prototype.endTape = function () { + this.state.gradientDepth--; + }; + /** + * Start a scope. Use this with endScope() to achieve the same functionality + * as scope() without the need for a function closure. + */ + Engine.prototype.startScope = function (name) { + var scopeInfo = { + track: [], + name: 'unnamed scope', + id: this.state.nextScopeId++ + }; + if (name) { + scopeInfo.name = name; + } + this.state.scopeStack.push(scopeInfo); + this.state.activeScope = scopeInfo; + }; + /** + * End a scope. Use this with startScope() to achieve the same functionality + * as scope() without the need for a function closure. + */ + Engine.prototype.endScope = function (result) { + var _this = this; + var tensorsToTrackInParent = getTensorsInContainer(result); + var tensorsToTrackInParentSet = new Set(tensorsToTrackInParent.map(function (t) { return t.id; })); + // Dispose the arrays tracked in this scope. + for (var i = 0; i < this.state.activeScope.track.length; i++) { + var tensor = this.state.activeScope.track[i]; + if (!tensor.kept && !tensorsToTrackInParentSet.has(tensor.id)) { + tensor.dispose(); + } + } + var oldScope = this.state.scopeStack.pop(); + this.state.activeScope = this.state.scopeStack.length === 0 ? + null : + this.state.scopeStack[this.state.scopeStack.length - 1]; + // Track the current result in the parent scope. + tensorsToTrackInParent.forEach(function (tensor) { + // Only track the tensor if was allocated in the inner scope and is not + // globally kept. + if (!tensor.kept && tensor.scopeId === oldScope.id) { + _this.track(tensor); + } + }); + }; + /** + * Returns gradients of `f` with respect to each of the `xs`. The gradients + * returned are of the same length as `xs`, but some might be null if `f` + * was not a function of that `x`. It also takes optional dy to multiply the + * gradient, which defaults to `1`. + */ + Engine.prototype.gradients = function (f, xs, dy, allowNoGradients) { + var _this = this; + if (allowNoGradients === void 0) { allowNoGradients = false; } + assert(xs.length > 0, function () { return 'gradients() received an empty list of xs.'; }); + if (dy != null && dy.dtype !== 'float32') { + throw new Error("dy must have 'float32' dtype, but has '" + dy.dtype + "'"); + } + var y = this.scopedRun(function () { return _this.startTape(); }, function () { return _this.endTape(); }, function () { return _this.tidy('forward', f); }); + assert(y instanceof Tensor, function () { return 'The result y returned by f() must be a tensor.'; }); + // Filter out the nodes that don't connect x => y. + var filteredTape = getFilteredNodesXToY(this.state.activeTape, xs, y); + if (!allowNoGradients && filteredTape.length === 0 && xs.length > 0) { + throw new Error('Cannot compute gradient of y=f(x) with respect to x. Make sure ' + + 'that the f you passed encloses all operations that lead from x ' + + 'to y.'); + } + return this.tidy('backward', function () { + var accumulatedGradientMap = {}; + accumulatedGradientMap[y.id] = (dy == null) ? ones(y.shape) : dy; + // Backprop gradients through the filtered nodes. + backpropagateGradients(accumulatedGradientMap, filteredTape, + // Pass the tidy function to avoid circular dep with `tape.ts`. + function (f) { return _this.tidy(f); }); + var grads = xs.map(function (x) { return accumulatedGradientMap[x.id]; }); + if (_this.state.gradientDepth === 0) { + // This means that we are not computing higher-order gradients + // and can clean up the tape. + _this.state.activeTape.forEach(function (node) { + for (var _i = 0, _a = node.saved; _i < _a.length; _i++) { + var tensor = _a[_i]; + tensor.dispose(); + } + }); + _this.state.activeTape = null; + } + return { value: y, grads: grads }; + }); + }; + Engine.prototype.customGrad = function (f) { + var _this = this; + assert(isFunction(f), function () { return 'The f passed in customGrad(f) must be a function.'; }); + return function () { + var inputs = []; + for (var _i = 0; _i < arguments.length; _i++) { + inputs[_i] = arguments[_i]; + } + assert(inputs.every(function (t) { return t instanceof Tensor; }), function () { return 'The args passed in customGrad(f)(x1, x2,...) must all be ' + + 'tensors'; }); + var res; + var inputMap = {}; + inputs.forEach(function (input, i) { + inputMap[i] = input; + }); + return _this.runKernelFunc(function (_, save) { + res = f.apply(void 0, inputs.concat([save])); + assert(res.value instanceof Tensor, function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.value` is a tensor'; }); + assert(isFunction(res.gradFunc), function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.gradFunc` is a function.'; }); + return res.value; + }, inputMap, function (dy, saved) { + var gradRes = res.gradFunc(dy, saved); + var grads = Array.isArray(gradRes) ? gradRes : [gradRes]; + assert(grads.length === inputs.length, function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.gradFunc` is a function that returns ' + + 'the same number of tensors as inputs passed to f(...).'; }); + assert(grads.every(function (t) { return t instanceof Tensor; }), function () { return 'The function f passed in customGrad(f) must return an ' + + 'object where `obj.gradFunc` is a function that returns ' + + 'a list of only tensors.'; }); + var gradMap = {}; + grads.forEach(function (grad, i) { + gradMap[i] = function () { return grad; }; + }); + return gradMap; + }); + }; + }; + Engine.prototype.readSync = function (dataId) { + // Route the read to the correct backend. + var info = this.state.tensorInfo.get(dataId); + return info.backend.readSync(dataId); + }; + Engine.prototype.read = function (dataId) { + // Route the read to the correct backend. + var info = this.state.tensorInfo.get(dataId); + return info.backend.read(dataId); + }; + Engine.prototype.fromPixels = function (pixels, numChannels) { + return this.backend.fromPixels(pixels, numChannels); + }; + Engine.prototype.time = function (query) { + return __awaiter(this, void 0, void 0, function () { + var start, timingInfo; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + start = now(); + return [4 /*yield*/, this.backend.time(query)]; + case 1: + timingInfo = _a.sent(); + timingInfo.wallMs = now() - start; + return [2 /*return*/, timingInfo]; + } + }); + }); + }; + /** + * Tracks a Tensor in the current scope to be automatically cleaned up + * when the current scope ends, and returns the value. + * + * @param result The Tensor to track in the current scope. + */ + Engine.prototype.track = function (result) { + if (this.state.activeScope != null) { + result.scopeId = this.state.activeScope.id; + this.state.activeScope.track.push(result); + } + return result; + }; + Object.defineProperty(Engine.prototype, "registeredVariables", { + get: function () { + return this.state.registeredVariables; + }, + enumerable: true, + configurable: true + }); + /** + * Resets the engine state. Removes all backends but does not remove + * registered backend factories. + */ + Engine.prototype.reset = function () { + // Make any pending promise obsolete. + this.pendingBackendInitId++; + this.state.dispose(); + this.ENV.reset(); + this.state = new EngineState(); + for (var backendName in this.registry) { + this.disposeRegisteredKernels(backendName); + this.registry[backendName].dispose(); + delete this.registry[backendName]; + } + this.backendName = null; + this.backendInstance = null; + this.pendingBackendInit = null; + }; + Engine.nextTensorId = 0; + Engine.nextVariableId = 0; + return Engine; + }()); + function ones(shape) { + var values = makeOnesTypedArray(sizeFromShape(shape), 'float32'); + return ENGINE.makeTensor(values, shape, 'float32'); + } + var GLOBAL; + function getGlobalNamespace() { + if (GLOBAL == null) { + // tslint:disable-next-line:no-any + var ns = void 0; + if (typeof (window) !== 'undefined') { + ns = window; + } + else if (typeof (global) !== 'undefined') { + ns = global; + } + else if (typeof (process) !== 'undefined') { + ns = process; + } + else if (typeof (self) !== 'undefined') { + ns = self; + } + else { + throw new Error('Could not find a global object'); + } + GLOBAL = ns; + } + return GLOBAL; + } + function getOrMakeEngine() { + var ns = getGlobalNamespace(); + if (ns._tfengine == null) { + var environment = new Environment(ns); + ns._tfengine = new Engine(environment); + } + setEnvironmentGlobal(ns._tfengine.ENV); + // Tell the current tensor interface that the global engine is responsible + // for tracking. + setTensorTracker(function () { return ns._tfengine; }); + return ns._tfengine; + } + var ENGINE = getOrMakeEngine(); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function isMobile() { + // tslint:disable-next-line:no-any + var a = navigator.userAgent || navigator.vendor || window.opera; + // tslint:disable-next-line:max-line-length + return /(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i + .test(a) || + // tslint:disable-next-line:max-line-length + /1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i + .test(a.substr(0, 4)); + } + function isBrowser() { + return (typeof window !== 'undefined' && window.document != null) || + //@ts-ignore + (typeof WorkerGlobalScope !== 'undefined'); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ENV = env(); + /** + * This file contains environment-related flag registrations. + */ + /** Whether to enable debug mode. */ + ENV.registerFlag('DEBUG', function () { return false; }, function (debugValue) { + if (debugValue) { + console.warn('Debugging mode is ON. The output of every math call will ' + + 'be downloaded to CPU and checked for NaNs. ' + + 'This significantly impacts performance.'); + } + }); + /** Whether we are in a browser (as versus, say, node.js) environment. */ + ENV.registerFlag('IS_BROWSER', function () { return isBrowser(); }); + /** Whether we are in a browser (as versus, say, node.js) environment. */ + ENV.registerFlag('IS_NODE', function () { return (typeof process !== 'undefined') && + (typeof process.versions !== 'undefined') && + (typeof process.versions.node !== 'undefined'); }); + /** Whether this browser is Chrome. */ + ENV.registerFlag('IS_CHROME', function () { return typeof navigator !== 'undefined' && navigator != null && + navigator.userAgent != null && /Chrome/.test(navigator.userAgent) && + /Google Inc/.test(navigator.vendor); }); + /** + * True when the environment is "production" where we disable safety checks + * to gain performance. + */ + ENV.registerFlag('PROD', function () { return false; }); + /** + * Whether to do sanity checks when inferring a shape from user-provided + * values, used when creating a new tensor. + */ + ENV.registerFlag('TENSORLIKE_CHECK_SHAPE_CONSISTENCY', function () { return ENV.getBool('DEBUG'); }); + /** Whether deprecation warnings are enabled. */ + ENV.registerFlag('DEPRECATION_WARNINGS_ENABLED', function () { return true; }); + /** True if running unit tests. */ + ENV.registerFlag('IS_TEST', function () { return false; }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var contexts = {}; + var WEBGL_ATTRIBUTES = { + alpha: false, + antialias: false, + premultipliedAlpha: false, + preserveDrawingBuffer: false, + depth: false, + stencil: false, + failIfMajorPerformanceCaveat: true + }; + function setWebGLContext(webGLVersion, gl) { + contexts[webGLVersion] = gl; + } + function getWebGLContext(webGLVersion) { + if (!(webGLVersion in contexts)) { + contexts[webGLVersion] = getWebGLRenderingContext(webGLVersion); + } + var gl = contexts[webGLVersion]; + if (gl.isContextLost()) { + delete contexts[webGLVersion]; + return getWebGLContext(webGLVersion); + } + gl.disable(gl.DEPTH_TEST); + gl.disable(gl.STENCIL_TEST); + gl.disable(gl.BLEND); + gl.disable(gl.DITHER); + gl.disable(gl.POLYGON_OFFSET_FILL); + gl.disable(gl.SAMPLE_COVERAGE); + gl.enable(gl.SCISSOR_TEST); + gl.enable(gl.CULL_FACE); + gl.cullFace(gl.BACK); + return contexts[webGLVersion]; + } + function createCanvas(webGLVersion) { + if (typeof OffscreenCanvas !== 'undefined' && webGLVersion === 2) { + return new OffscreenCanvas(300, 150); + } + else if (typeof document !== 'undefined') { + return document.createElement('canvas'); + } + else { + throw new Error('Cannot create a canvas in this context'); + } + } + function getWebGLRenderingContext(webGLVersion) { + if (webGLVersion !== 1 && webGLVersion !== 2) { + throw new Error('Cannot get WebGL rendering context, WebGL is disabled.'); + } + var canvas = createCanvas(webGLVersion); + canvas.addEventListener('webglcontextlost', function (ev) { + ev.preventDefault(); + delete contexts[webGLVersion]; + }, false); + if (webGLVersion === 1) { + return (canvas.getContext('webgl', WEBGL_ATTRIBUTES) || + canvas.getContext('experimental-webgl', WEBGL_ATTRIBUTES)); + } + return canvas.getContext('webgl2', WEBGL_ATTRIBUTES); + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PackingScheme; + (function (PackingScheme) { + /** + * All values in a single texel are densely packed without any constraints. + * + * This is how the shader encodes a tensor with shape = [2, 3, 4] + * (indices are [batch, row, col]). + * + * 000|001 010|011 020|021 + * ------- ------- ------- + * 002|003 012|013 022|023 + * + * 100|101 110|111 120|121 + * ------- ------- ------- + * 102|103 112|113 122|123 + * + */ + PackingScheme[PackingScheme["DENSE"] = 0] = "DENSE"; + /** + * Single texels contain only values from the same batch, and from adjacent + * rows and columns. + * + * This is how the shader encodes a tensor with shape = [2, 3, 5] + * (indices are [batch, row, col]). + * + * 000|001 002|003 004|xxx 020|021 022|023 024|xxx + * ------- ------- ------- ------- ------- ------- + * 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx + * + * 100|101 102|103 104|xxx 120|121 122|123 124|xxx + * ------- ------- ------- ------- ------- ------- + * 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx + * + */ + PackingScheme[PackingScheme["SHARED_BATCH"] = 1] = "SHARED_BATCH"; + })(PackingScheme || (PackingScheme = {})); + var TextureUsage; + (function (TextureUsage) { + TextureUsage[TextureUsage["RENDER"] = 0] = "RENDER"; + TextureUsage[TextureUsage["UPLOAD"] = 1] = "UPLOAD"; + TextureUsage[TextureUsage["PIXELS"] = 2] = "PIXELS"; + TextureUsage[TextureUsage["DOWNLOAD"] = 3] = "DOWNLOAD"; + })(TextureUsage || (TextureUsage = {})); + var PhysicalTextureType; + (function (PhysicalTextureType) { + PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT16"] = 0] = "UNPACKED_FLOAT16"; + PhysicalTextureType[PhysicalTextureType["UNPACKED_FLOAT32"] = 1] = "UNPACKED_FLOAT32"; + PhysicalTextureType[PhysicalTextureType["PACKED_4X1_UNSIGNED_BYTE"] = 2] = "PACKED_4X1_UNSIGNED_BYTE"; + PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT32"] = 3] = "PACKED_2X2_FLOAT32"; + PhysicalTextureType[PhysicalTextureType["PACKED_2X2_FLOAT16"] = 4] = "PACKED_2X2_FLOAT16"; + })(PhysicalTextureType || (PhysicalTextureType = {})); + function getUnpackedMatrixTextureShapeWidthHeight(rows, columns) { + return [columns, rows]; + } + function getUnpackedArraySizeFromMatrixSize(matrixSize, channelsPerTexture) { + return matrixSize * channelsPerTexture; + } + /** + * Get shape for densely packed RGBA texture. + */ + function getDenseTexShape(shape) { + var size = sizeFromShape(shape); + var texelsNeeded = Math.ceil(size / 4); + return sizeToSquarishShape(texelsNeeded); + } + function getPackedMatrixTextureShapeWidthHeight(rows, columns) { + return [ + Math.max(1, Math.ceil(columns / 2)), Math.max(1, Math.ceil(rows / 2)) + ]; + } + function getPackedRGBAArraySizeFromMatrixShape(rows, columns) { + var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), w = _a[0], h = _a[1]; + return w * h * 4; + } + function getTextureConfig( + // tslint:disable-next-line:no-any + gl, textureHalfFloatExtension) { + // tslint:disable-next-line:no-any + var glany = gl; + var internalFormatFloat; + var internalFormatHalfFloat; + var internalFormatPackedHalfFloat; + var internalFormatPackedFloat; + var textureFormatFloat; + var downloadTextureFormat; + var downloadUnpackNumChannels; + var defaultNumChannels; + var textureTypeHalfFloat; + var textureTypeFloat; + if (env().getNumber('WEBGL_VERSION') === 2) { + internalFormatFloat = glany.R32F; + internalFormatHalfFloat = glany.R16F; + internalFormatPackedHalfFloat = glany.RGBA16F; + internalFormatPackedFloat = glany.RGBA32F; + textureFormatFloat = glany.RED; + downloadUnpackNumChannels = 4; + defaultNumChannels = 1; + textureTypeHalfFloat = glany.HALF_FLOAT; + textureTypeFloat = glany.FLOAT; + } + else { + internalFormatFloat = gl.RGBA; + internalFormatHalfFloat = gl.RGBA; + internalFormatPackedHalfFloat = gl.RGBA; + internalFormatPackedFloat = glany.RGBA; + textureFormatFloat = gl.RGBA; + downloadUnpackNumChannels = 4; + defaultNumChannels = 4; + textureTypeHalfFloat = textureHalfFloatExtension != null ? + textureHalfFloatExtension.HALF_FLOAT_OES : + null; + textureTypeFloat = gl.FLOAT; + } + downloadTextureFormat = gl.RGBA; + return { + internalFormatFloat: internalFormatFloat, + internalFormatHalfFloat: internalFormatHalfFloat, + internalFormatPackedHalfFloat: internalFormatPackedHalfFloat, + internalFormatPackedFloat: internalFormatPackedFloat, + textureFormatFloat: textureFormatFloat, + downloadTextureFormat: downloadTextureFormat, + downloadUnpackNumChannels: downloadUnpackNumChannels, + defaultNumChannels: defaultNumChannels, + textureTypeHalfFloat: textureTypeHalfFloat, + textureTypeFloat: textureTypeFloat + }; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function callAndCheck(gl, debugMode, func) { + var returnValue = func(); + if (debugMode) { + checkWebGLError(gl); + } + return returnValue; + } + function checkWebGLError(gl) { + var error = gl.getError(); + if (error !== gl.NO_ERROR) { + throw new Error('WebGL Error: ' + getWebGLErrorMessage(gl, error)); + } + } + // https://en.wikipedia.org/wiki/Half-precision_floating-point_format + var MIN_FLOAT16 = 5.96e-8; + var MAX_FLOAT16 = 65504; + function canBeRepresented(num) { + if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED') || num === 0 || + (MIN_FLOAT16 < Math.abs(num) && Math.abs(num) < MAX_FLOAT16)) { + return true; + } + return false; + } + function getWebGLErrorMessage(gl, status) { + switch (status) { + case gl.NO_ERROR: + return 'NO_ERROR'; + case gl.INVALID_ENUM: + return 'INVALID_ENUM'; + case gl.INVALID_VALUE: + return 'INVALID_VALUE'; + case gl.INVALID_OPERATION: + return 'INVALID_OPERATION'; + case gl.INVALID_FRAMEBUFFER_OPERATION: + return 'INVALID_FRAMEBUFFER_OPERATION'; + case gl.OUT_OF_MEMORY: + return 'OUT_OF_MEMORY'; + case gl.CONTEXT_LOST_WEBGL: + return 'CONTEXT_LOST_WEBGL'; + default: + return "Unknown error code " + status; + } + } + function getExtensionOrThrow(gl, debug, extensionName) { + return throwIfNull(gl, debug, function () { return gl.getExtension(extensionName); }, 'Extension "' + extensionName + '" not supported on this browser.'); + } + function createVertexShader(gl, debug, vertexShaderSource) { + var vertexShader = throwIfNull(gl, debug, function () { return gl.createShader(gl.VERTEX_SHADER); }, 'Unable to create vertex WebGLShader.'); + callAndCheck(gl, debug, function () { return gl.shaderSource(vertexShader, vertexShaderSource); }); + callAndCheck(gl, debug, function () { return gl.compileShader(vertexShader); }); + if (gl.getShaderParameter(vertexShader, gl.COMPILE_STATUS) === false) { + console.log(gl.getShaderInfoLog(vertexShader)); + throw new Error('Failed to compile vertex shader.'); + } + return vertexShader; + } + function createFragmentShader(gl, debug, fragmentShaderSource) { + var fragmentShader = throwIfNull(gl, debug, function () { return gl.createShader(gl.FRAGMENT_SHADER); }, 'Unable to create fragment WebGLShader.'); + callAndCheck(gl, debug, function () { return gl.shaderSource(fragmentShader, fragmentShaderSource); }); + callAndCheck(gl, debug, function () { return gl.compileShader(fragmentShader); }); + if (gl.getShaderParameter(fragmentShader, gl.COMPILE_STATUS) === false) { + logShaderSourceAndInfoLog(fragmentShaderSource, gl.getShaderInfoLog(fragmentShader)); + throw new Error('Failed to compile fragment shader.'); + } + return fragmentShader; + } + var lineNumberRegex = /ERROR: [0-9]+:([0-9]+):/g; + function logShaderSourceAndInfoLog(shaderSource, shaderInfoLog) { + var lineNumberRegexResult = lineNumberRegex.exec(shaderInfoLog); + if (lineNumberRegexResult == null) { + console.log("Couldn't parse line number in error: " + shaderInfoLog); + console.log(shaderSource); + return; + } + var lineNumber = +lineNumberRegexResult[1]; + var shaderLines = shaderSource.split('\n'); + var pad = shaderLines.length.toString().length + 2; + var linesWithLineNumbers = shaderLines.map(function (line, lineNumber) { + return rightPad((lineNumber + 1).toString(), pad) + line; + }); + var maxLineLength = 0; + for (var i = 0; i < linesWithLineNumbers.length; i++) { + maxLineLength = Math.max(linesWithLineNumbers[i].length, maxLineLength); + } + var beforeErrorLines = linesWithLineNumbers.slice(0, lineNumber - 1); + var errorLine = linesWithLineNumbers.slice(lineNumber - 1, lineNumber); + var afterErrorLines = linesWithLineNumbers.slice(lineNumber); + console.log(beforeErrorLines.join('\n')); + console.log(shaderInfoLog.split('\n')[0]); + console.log("%c " + rightPad(errorLine[0], maxLineLength), 'border:1px solid red; background-color:#e3d2d2; color:#a61717'); + console.log(afterErrorLines.join('\n')); + } + function createProgram(gl, debug) { + return throwIfNull(gl, debug, function () { return gl.createProgram(); }, 'Unable to create WebGLProgram.'); + } + function linkProgram(gl, debug, program) { + callAndCheck(gl, debug, function () { return gl.linkProgram(program); }); + if (gl.getProgramParameter(program, gl.LINK_STATUS) === false) { + console.log(gl.getProgramInfoLog(program)); + throw new Error('Failed to link vertex and fragment shaders.'); + } + } + function validateProgram(gl, debug, program) { + callAndCheck(gl, debug, function () { return gl.validateProgram(program); }); + if (gl.getProgramParameter(program, gl.VALIDATE_STATUS) === false) { + console.log(gl.getProgramInfoLog(program)); + throw new Error('Shader program validation failed.'); + } + } + function createStaticVertexBuffer(gl, debug, data) { + var buffer = throwIfNull(gl, debug, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer'); + callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); }); + callAndCheck(gl, debug, function () { return gl.bufferData(gl.ARRAY_BUFFER, data, gl.STATIC_DRAW); }); + return buffer; + } + function createStaticIndexBuffer(gl, debug, data) { + var buffer = throwIfNull(gl, debug, function () { return gl.createBuffer(); }, 'Unable to create WebGLBuffer'); + callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, buffer); }); + callAndCheck(gl, debug, function () { return gl.bufferData(gl.ELEMENT_ARRAY_BUFFER, data, gl.STATIC_DRAW); }); + return buffer; + } + function getNumChannels() { + if (env().getNumber('WEBGL_VERSION') === 2) { + return 1; + } + return 4; + } + function createTexture(gl, debug) { + return throwIfNull(gl, debug, function () { return gl.createTexture(); }, 'Unable to create WebGLTexture.'); + } + function validateTextureSize(width, height) { + var maxTextureSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + if ((width <= 0) || (height <= 0)) { + var requested = "[" + width + "x" + height + "]"; + throw new Error('Requested texture size ' + requested + ' is invalid.'); + } + if ((width > maxTextureSize) || (height > maxTextureSize)) { + var requested = "[" + width + "x" + height + "]"; + var max = "[" + maxTextureSize + "x" + maxTextureSize + "]"; + throw new Error('Requested texture size ' + requested + + ' greater than WebGL maximum on this browser / GPU ' + max + '.'); + } + } + function createFramebuffer(gl, debug) { + return throwIfNull(gl, debug, function () { return gl.createFramebuffer(); }, 'Unable to create WebGLFramebuffer.'); + } + function bindVertexBufferToProgramAttribute(gl, debug, program, attribute, buffer, arrayEntriesPerItem, itemStrideInBytes, itemOffsetInBytes) { + var loc = gl.getAttribLocation(program, attribute); + if (loc === -1) { + // The GPU compiler decided to strip out this attribute because it's unused, + // thus no need to bind. + return false; + } + callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, buffer); }); + callAndCheck(gl, debug, function () { return gl.vertexAttribPointer(loc, arrayEntriesPerItem, gl.FLOAT, false, itemStrideInBytes, itemOffsetInBytes); }); + callAndCheck(gl, debug, function () { return gl.enableVertexAttribArray(loc); }); + return true; + } + function bindTextureUnit(gl, debug, texture, textureUnit) { + validateTextureUnit(gl, textureUnit); + callAndCheck(gl, debug, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); }); + callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); + } + function unbindTextureUnit(gl, debug, textureUnit) { + validateTextureUnit(gl, textureUnit); + callAndCheck(gl, debug, function () { return gl.activeTexture(gl.TEXTURE0 + textureUnit); }); + callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); + } + function getProgramUniformLocationOrThrow(gl, debug, program, uniformName) { + return throwIfNull(gl, debug, function () { return gl.getUniformLocation(program, uniformName); }, 'uniform "' + uniformName + '" not present in program.'); + } + function getProgramUniformLocation(gl, program, uniformName) { + return gl.getUniformLocation(program, uniformName); + } + function bindTextureToProgramUniformSampler(gl, debug, program, texture, uniformSamplerLocation, textureUnit) { + callAndCheck(gl, debug, function () { return bindTextureUnit(gl, debug, texture, textureUnit); }); + callAndCheck(gl, debug, function () { return gl.uniform1i(uniformSamplerLocation, textureUnit); }); + } + function bindCanvasToFramebuffer(gl, debug) { + callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); }); + callAndCheck(gl, debug, function () { return gl.viewport(0, 0, gl.canvas.width, gl.canvas.height); }); + callAndCheck(gl, debug, function () { return gl.scissor(0, 0, gl.canvas.width, gl.canvas.height); }); + } + function bindColorTextureToFramebuffer(gl, debug, texture, framebuffer) { + callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); }); + callAndCheck(gl, debug, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); }); + } + function unbindColorTextureFromFramebuffer(gl, debug, framebuffer) { + callAndCheck(gl, debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer); }); + callAndCheck(gl, debug, function () { return gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, null, 0); }); + } + function validateFramebuffer(gl) { + var status = gl.checkFramebufferStatus(gl.FRAMEBUFFER); + if (status !== gl.FRAMEBUFFER_COMPLETE) { + throw new Error('Error binding framebuffer: ' + getFramebufferErrorMessage(gl, status)); + } + } + function getFramebufferErrorMessage(gl, status) { + switch (status) { + case gl.FRAMEBUFFER_INCOMPLETE_ATTACHMENT: + return 'FRAMEBUFFER_INCOMPLETE_ATTACHMENT'; + case gl.FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT: + return 'FRAMEBUFFER_INCOMPLETE_MISSING_ATTACHMENT'; + case gl.FRAMEBUFFER_INCOMPLETE_DIMENSIONS: + return 'FRAMEBUFFER_INCOMPLETE_DIMENSIONS'; + case gl.FRAMEBUFFER_UNSUPPORTED: + return 'FRAMEBUFFER_UNSUPPORTED'; + default: + return "unknown error " + status; + } + } + function throwIfNull(gl, debug, returnTOrNull, failureMessage) { + var tOrNull = callAndCheck(gl, debug, function () { return returnTOrNull(); }); + if (tOrNull == null) { + throw new Error(failureMessage); + } + return tOrNull; + } + function validateTextureUnit(gl, textureUnit) { + var maxTextureUnit = gl.MAX_COMBINED_TEXTURE_IMAGE_UNITS - 1; + var glTextureUnit = textureUnit + gl.TEXTURE0; + if (glTextureUnit < gl.TEXTURE0 || glTextureUnit > maxTextureUnit) { + var textureUnitRange = "[gl.TEXTURE0, gl.TEXTURE" + maxTextureUnit + "]"; + throw new Error("textureUnit must be in " + textureUnitRange + "."); + } + } + function getBatchDim(shape, dimsToSkip) { + if (dimsToSkip === void 0) { dimsToSkip = 2; } + return sizeFromShape(shape.slice(0, shape.length - dimsToSkip)); + } + function getRowsCols(shape) { + if (shape.length === 0) { + throw Error('Cannot get rows and columns of an empty shape array.'); + } + return [ + shape.length > 1 ? shape[shape.length - 2] : 1, shape[shape.length - 1] + ]; + } + function getShapeAs3D(shape) { + var shapeAs3D = [1, 1, 1]; + var isScalar = shape.length === 0 || (shape.length === 1 && shape[0] === 1); + if (!isScalar) { + shapeAs3D = + [getBatchDim(shape)].concat(getRowsCols(shape)); + } + return shapeAs3D; + } + function getTextureShapeFromLogicalShape(logShape, isPacked) { + var _a; + if (isPacked === void 0) { isPacked = false; } + var maxTexSize = env().getNumber('WEBGL_MAX_TEXTURE_SIZE'); + if (isPacked) { + maxTexSize = maxTexSize * 2; + // This logic ensures we accurately count the number of packed texels needed + // to accommodate the tensor. We can only pack values in the same texel if + // they are from adjacent pairs of rows/cols within the same batch. So if a + // tensor has 3 rows, we pretend it has 4 rows in order to account for the + // fact that the texels containing the third row are half empty. + logShape = logShape.map(function (d, i) { return i >= logShape.length - 2 ? + nearestLargerEven(logShape[i]) : + logShape[i]; }); + // Packed texture height is at least 2 (the channel height of a single + // texel). + if (logShape.length === 1) { + logShape = [2, logShape[0]]; + } + } + // If logical shape is 2, we don't squeeze, since we want to match physical. + if (logShape.length !== 2) { + var squeezeResult = squeezeShape(logShape); + logShape = squeezeResult.newShape; + } + var size = sizeFromShape(logShape); + if (logShape.length <= 1 && size <= maxTexSize) { + return [1, size]; + } + else if (logShape.length === 2 && logShape[0] <= maxTexSize && + logShape[1] <= maxTexSize) { + return logShape; + } + else if (logShape.length === 3 && logShape[0] * logShape[1] <= maxTexSize && + logShape[2] <= maxTexSize) { + return [logShape[0] * logShape[1], logShape[2]]; + } + else if (logShape.length === 3 && logShape[0] <= maxTexSize && + logShape[1] * logShape[2] <= maxTexSize) { + return [logShape[0], logShape[1] * logShape[2]]; + } + else if (logShape.length === 4 && + logShape[0] * logShape[1] * logShape[2] <= maxTexSize && + logShape[3] <= maxTexSize) { + return [logShape[0] * logShape[1] * logShape[2], logShape[3]]; + } + else if (logShape.length === 4 && logShape[0] <= maxTexSize && + logShape[1] * logShape[2] * logShape[3] <= maxTexSize) { + return [logShape[0], logShape[1] * logShape[2] * logShape[3]]; + } + else { + if (isPacked) { + // For packed textures size equals the number of channels required to + // accommodate the texture data. However in order to squarify such that + // inner dimensions stay even, we rewrite size to equal the number of + // texels. Then in the return statement we rehydrate the squarified + // dimensions to channel units. + var batchDim = getBatchDim(logShape); + var rows = 2, cols = 2; + if (logShape.length) { + _a = getRowsCols(logShape), rows = _a[0], cols = _a[1]; + } + size = batchDim * (rows / 2) * (cols / 2); + return sizeToSquarishShape(size).map(function (d) { return d * 2; }); + } + return sizeToSquarishShape(size); + } + } + function isEven(n) { + return n % 2 === 0; + } + /** + * This determines whether reshaping a packed texture requires rearranging + * the data within the texture, assuming 2x2 packing. + */ + function isReshapeFree(shape1, shape2) { + shape1 = shape1.slice(-2); + shape2 = shape2.slice(-2); + if (arraysEqual(shape1, shape2)) { + return true; + } + if (!shape1.length || !shape2.length) { // One of the shapes is a scalar. + return true; + } + if (shape1[0] === 0 || shape1[1] === 0 || shape2[0] === 0 || + shape2[1] === 0) { + return true; + } + if (shape1.length !== shape2.length) { // One of the shapes is a vector. + var shape1Cols = shape1.slice(-1)[0]; + var shape2Cols = shape2.slice(-1)[0]; + if (shape1Cols === shape2Cols) { + return true; + } + if (isEven(shape1Cols) && isEven(shape2Cols) && + (shape1[0] === 1 || shape2[0] === 1)) { + return true; + } + } + return shape1[1] === shape2[1] && isEven(shape1[0]) && isEven(shape2[0]); + } + // We cache webgl params because the environment gets reset between + // unit tests and we don't want to constantly query the WebGLContext for + // MAX_TEXTURE_SIZE. + var MAX_TEXTURE_SIZE; + var MAX_TEXTURES_IN_SHADER; + function getWebGLMaxTextureSize(webGLVersion) { + if (MAX_TEXTURE_SIZE == null) { + var gl = getWebGLContext(webGLVersion); + MAX_TEXTURE_SIZE = gl.getParameter(gl.MAX_TEXTURE_SIZE); + } + return MAX_TEXTURE_SIZE; + } + function resetMaxTextureSize() { + MAX_TEXTURE_SIZE = null; + } + function resetMaxTexturesInShader() { + MAX_TEXTURES_IN_SHADER = null; + } + function getMaxTexturesInShader(webGLVersion) { + if (MAX_TEXTURES_IN_SHADER == null) { + var gl = getWebGLContext(webGLVersion); + MAX_TEXTURES_IN_SHADER = gl.getParameter(gl.MAX_TEXTURE_IMAGE_UNITS); + } + // We cap at 16 to avoid spurious runtime "memory exhausted" error. + return Math.min(16, MAX_TEXTURES_IN_SHADER); + } + function getWebGLDisjointQueryTimerVersion(webGLVersion) { + if (webGLVersion === 0) { + return 0; + } + var queryTimerVersion; + var gl = getWebGLContext(webGLVersion); + if (hasExtension(gl, 'EXT_disjoint_timer_query_webgl2') && + webGLVersion === 2) { + queryTimerVersion = 2; + } + else if (hasExtension(gl, 'EXT_disjoint_timer_query')) { + queryTimerVersion = 1; + } + else { + queryTimerVersion = 0; + } + return queryTimerVersion; + } + function hasExtension(gl, extensionName) { + var ext = gl.getExtension(extensionName); + return ext != null; + } + function isWebGLVersionEnabled(webGLVersion) { + try { + var gl = getWebGLContext(webGLVersion); + if (gl != null) { + return true; + } + } + catch (e) { + return false; + } + return false; + } + function isCapableOfRenderingToFloatTexture(webGLVersion) { + if (webGLVersion === 0) { + return false; + } + var gl = getWebGLContext(webGLVersion); + if (webGLVersion === 1) { + if (!hasExtension(gl, 'OES_texture_float')) { + return false; + } + } + else { + if (!hasExtension(gl, 'EXT_color_buffer_float')) { + return false; + } + } + var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); + return isFrameBufferComplete; + } + /** + * Check if we can download values from a float/half-float texture. + * + * Note that for performance reasons we use binding a texture to a framebuffer + * as a proxy for ability to download float values later using readPixels. The + * texture params of this texture will not match those in readPixels exactly + * but if we are unable to bind some kind of float texture to the frameBuffer + * then we definitely will not be able to read float values from it. + */ + function isDownloadFloatTextureEnabled(webGLVersion) { + if (webGLVersion === 0) { + return false; + } + var gl = getWebGLContext(webGLVersion); + if (webGLVersion === 1) { + if (!hasExtension(gl, 'OES_texture_float')) { + return false; + } + if (!hasExtension(gl, 'WEBGL_color_buffer_float')) { + return false; + } + } + else { + if (hasExtension(gl, 'EXT_color_buffer_float')) { + return createFloatTextureAndBindToFramebuffer(gl); + } + var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; + if (hasExtension(gl, COLOR_BUFFER_HALF_FLOAT)) { + var textureHalfFloatExtension = gl.getExtension(COLOR_BUFFER_HALF_FLOAT); + return createHalfFloatTextureAndBindToFramebuffer(gl, textureHalfFloatExtension); + } + return false; + } + var isFrameBufferComplete = createFloatTextureAndBindToFramebuffer(gl); + return isFrameBufferComplete; + } + function createFloatTextureAndBindToFramebuffer(gl) { + var texConfig = getTextureConfig(gl); + var texture = gl.createTexture(); + gl.bindTexture(gl.TEXTURE_2D, texture); + var width = 1; + var height = 1; + gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeFloat, null); + var frameBuffer = gl.createFramebuffer(); + gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); + gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); + var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; + gl.bindTexture(gl.TEXTURE_2D, null); + gl.bindFramebuffer(gl.FRAMEBUFFER, null); + gl.deleteTexture(texture); + gl.deleteFramebuffer(frameBuffer); + return isFrameBufferComplete; + } + function createHalfFloatTextureAndBindToFramebuffer( + // tslint:disable-next-line:no-any + gl, textureHalfFloatExtension) { + var texConfig = getTextureConfig(gl, textureHalfFloatExtension); + var texture = gl.createTexture(); + gl.bindTexture(gl.TEXTURE_2D, texture); + var width = 1; + var height = 1; + gl.texImage2D(gl.TEXTURE_2D, 0, texConfig.internalFormatHalfFloat, width, height, 0, texConfig.textureFormatFloat, texConfig.textureTypeHalfFloat, null); + var frameBuffer = gl.createFramebuffer(); + gl.bindFramebuffer(gl.FRAMEBUFFER, frameBuffer); + gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0); + var isFrameBufferComplete = gl.checkFramebufferStatus(gl.FRAMEBUFFER) === gl.FRAMEBUFFER_COMPLETE; + gl.bindTexture(gl.TEXTURE_2D, null); + gl.bindFramebuffer(gl.FRAMEBUFFER, null); + gl.deleteTexture(texture); + gl.deleteFramebuffer(frameBuffer); + return isFrameBufferComplete; + } + function isWebGLFenceEnabled(webGLVersion) { + if (webGLVersion !== 2) { + return false; + } + var gl = getWebGLContext(webGLVersion); + // tslint:disable-next-line:no-any + var isEnabled = gl.fenceSync != null; + return isEnabled; + } + + var webgl_util = /*#__PURE__*/Object.freeze({ + callAndCheck: callAndCheck, + canBeRepresented: canBeRepresented, + getWebGLErrorMessage: getWebGLErrorMessage, + getExtensionOrThrow: getExtensionOrThrow, + createVertexShader: createVertexShader, + createFragmentShader: createFragmentShader, + createProgram: createProgram, + linkProgram: linkProgram, + validateProgram: validateProgram, + createStaticVertexBuffer: createStaticVertexBuffer, + createStaticIndexBuffer: createStaticIndexBuffer, + getNumChannels: getNumChannels, + createTexture: createTexture, + validateTextureSize: validateTextureSize, + createFramebuffer: createFramebuffer, + bindVertexBufferToProgramAttribute: bindVertexBufferToProgramAttribute, + bindTextureUnit: bindTextureUnit, + unbindTextureUnit: unbindTextureUnit, + getProgramUniformLocationOrThrow: getProgramUniformLocationOrThrow, + getProgramUniformLocation: getProgramUniformLocation, + bindTextureToProgramUniformSampler: bindTextureToProgramUniformSampler, + bindCanvasToFramebuffer: bindCanvasToFramebuffer, + bindColorTextureToFramebuffer: bindColorTextureToFramebuffer, + unbindColorTextureFromFramebuffer: unbindColorTextureFromFramebuffer, + validateFramebuffer: validateFramebuffer, + getFramebufferErrorMessage: getFramebufferErrorMessage, + getBatchDim: getBatchDim, + getRowsCols: getRowsCols, + getShapeAs3D: getShapeAs3D, + getTextureShapeFromLogicalShape: getTextureShapeFromLogicalShape, + isReshapeFree: isReshapeFree, + getWebGLMaxTextureSize: getWebGLMaxTextureSize, + resetMaxTextureSize: resetMaxTextureSize, + resetMaxTexturesInShader: resetMaxTexturesInShader, + getMaxTexturesInShader: getMaxTexturesInShader, + getWebGLDisjointQueryTimerVersion: getWebGLDisjointQueryTimerVersion, + hasExtension: hasExtension, + isWebGLVersionEnabled: isWebGLVersionEnabled, + isCapableOfRenderingToFloatTexture: isCapableOfRenderingToFloatTexture, + isDownloadFloatTextureEnabled: isDownloadFloatTextureEnabled, + isWebGLFenceEnabled: isWebGLFenceEnabled + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ENV$1 = env(); + /** + * This file contains WebGL-specific flag registrations. + */ + /** + * True if WebGL is supported. + */ + ENV$1.registerFlag('HAS_WEBGL', function () { return ENV$1.getNumber('WEBGL_VERSION') > 0; }); + /** 0: No WebGL, 1: WebGL 1.0, 2: WebGL 2.0. */ + ENV$1.registerFlag('WEBGL_VERSION', function () { + if (isWebGLVersionEnabled(2)) { + return 2; + } + else if (isWebGLVersionEnabled(1)) { + return 1; + } + return 0; + }); + ENV$1.registerFlag('WEBGL_BUFFER_SUPPORTED', function () { return ENV$1.get('WEBGL_VERSION') === 2; }); + /** Whether the WebGL backend will sometimes forward ops to the CPU. */ + ENV$1.registerFlag('WEBGL_CPU_FORWARD', function () { return true; }); + /** Whether the WebGL backend will always use f16 textures for rendering. */ + ENV$1.registerFlag('WEBGL_FORCE_F16_TEXTURES', function () { return false; }); + /** Whether to turn all packing related flags on. */ + ENV$1.registerFlag('WEBGL_PACK', function () { return ENV$1.getBool('HAS_WEBGL'); }); + /** Whether we will pack the batchnormalization op. */ + ENV$1.registerFlag('WEBGL_PACK_NORMALIZATION', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether we will pack the clip op. */ + ENV$1.registerFlag('WEBGL_PACK_CLIP', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether we will pack the depthwise conv op. */ + // TODO: https://github.com/tensorflow/tfjs/issues/1679 + ENV$1.registerFlag('WEBGL_PACK_DEPTHWISECONV', function () { return false; }); + /** Whether we will pack binary ops. */ + ENV$1.registerFlag('WEBGL_PACK_BINARY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether we will pack unary ops. */ + ENV$1.registerFlag('WEBGL_PACK_UNARY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether we will pack array ops. */ + ENV$1.registerFlag('WEBGL_PACK_ARRAY_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether we will pack image ops. */ + ENV$1.registerFlag('WEBGL_PACK_IMAGE_OPERATIONS', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether we will pack reduce ops. */ + ENV$1.registerFlag('WEBGL_PACK_REDUCE', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether packed WebGL kernels lazily unpack their outputs. */ + ENV$1.registerFlag('WEBGL_LAZILY_UNPACK', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** Whether we will use the im2col algorithm to speed up convolutions. */ + ENV$1.registerFlag('WEBGL_CONV_IM2COL', function () { return ENV$1.getBool('WEBGL_PACK'); }); + /** The maximum texture dimension. */ + ENV$1.registerFlag('WEBGL_MAX_TEXTURE_SIZE', function () { return getWebGLMaxTextureSize(ENV$1.getNumber('WEBGL_VERSION')); }); + /** The maximum texture dimension. */ + ENV$1.registerFlag('WEBGL_MAX_TEXTURES_IN_SHADER', function () { return getMaxTexturesInShader(ENV$1.getNumber('WEBGL_VERSION')); }); + /** + * The disjoint_query_timer extension version. + * 0: disabled, 1: EXT_disjoint_timer_query, 2: + * EXT_disjoint_timer_query_webgl2. + * In Firefox with WebGL 2.0, + * EXT_disjoint_timer_query_webgl2 is not available, so we must use the + * WebGL 1.0 extension. + */ + ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION', function () { + var webGLVersion = ENV$1.getNumber('WEBGL_VERSION'); + if (webGLVersion === 0) { + return 0; + } + return getWebGLDisjointQueryTimerVersion(webGLVersion); + }); + /** + * Whether the timer object from the disjoint_query_timer extension gives + * timing information that is reliable. + */ + ENV$1.registerFlag('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_RELIABLE', function () { return ENV$1.getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0 && + !isMobile(); }); + /** + * Whether the device is physically capable of rendering to float32 textures. + */ + ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_CAPABLE', function () { return isCapableOfRenderingToFloatTexture(ENV$1.getNumber('WEBGL_VERSION')); }); + /** + * Whether rendering to float32 textures is enabled. If disabled, renders to + * float16 textures. + */ + ENV$1.registerFlag('WEBGL_RENDER_FLOAT32_ENABLED', function () { + return ENV$1.getBool('WEBGL_FORCE_F16_TEXTURES') ? + false : + ENV$1.getBool('WEBGL_RENDER_FLOAT32_CAPABLE'); + }); + /** + * Whether downloading float textures is enabled (16 or 32 bit). If disabled, + * uses IEEE 754 encoding of the float32 values to 4 uint8 when downloading. + */ + ENV$1.registerFlag('WEBGL_DOWNLOAD_FLOAT_ENABLED', function () { return isDownloadFloatTextureEnabled(ENV$1.getNumber('WEBGL_VERSION')); }); + /** Whether the fence API is available. */ + ENV$1.registerFlag('WEBGL_FENCE_API_ENABLED', function () { return isWebGLFenceEnabled(ENV$1.getNumber('WEBGL_VERSION')); }); + /** + * Tensors with size <= than this will be uploaded as uniforms, not textures. + */ + ENV$1.registerFlag('WEBGL_SIZE_UPLOAD_UNIFORM', function () { + // Use uniform uploads only when 32bit floats are supported. In + // 16bit + // environments there are problems with comparing a 16bit texture value + // with a 32bit uniform value. + var useUniforms = ENV$1.getBool('WEBGL_RENDER_FLOAT32_ENABLED'); + return useUniforms ? 4 : 0; + }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Enables production mode which disables correctness checks in favor of + * performance. + */ + /** @doc {heading: 'Environment'} */ + function enableProdMode() { + env().set('PROD', true); + } + /** + * Enables debug mode which will log information about all executed kernels: + * the elapsed time of the kernel execution, as well as the rank, shape, and + * size of the output tensor. + * + * Debug mode will significantly slow down your application as it will + * download the result of every operation to the CPU. This should not be used in + * production. Debug mode does not affect the timing information of the kernel + * execution as we do not measure download time in the kernel execution time. + * + * See also: `tf.profile`, `tf.memory`. + */ + /** @doc {heading: 'Environment'} */ + function enableDebugMode() { + env().set('DEBUG', true); + } + /** Globally disables deprecation warnings */ + function disableDeprecationWarnings() { + env().set('DEPRECATION_WARNINGS_ENABLED', false); + console.warn("TensorFlow.js deprecation warnings have been disabled."); + } + /** Warn users about deprecated functionality. */ + function deprecationWarn(msg) { + if (env().getBool('DEPRECATION_WARNINGS_ENABLED')) { + console.warn(msg + ' You can disable deprecation warnings with ' + + 'tf.disableDeprecationWarnings().'); + } + } + setDeprecationWarningFn(deprecationWarn); + /** + * Dispose all variables kept in backend engine. + */ + /** @doc {heading: 'Environment'} */ + function disposeVariables() { + ENGINE.disposeVariables(); + } + /** + * It returns the global engine that keeps track of all tensors and backends. + */ + /** @doc {heading: 'Environment'} */ + function engine() { + return ENGINE; + } + /** + * Returns memory info at the current time in the program. The result is an + * object with the following properties: + * + * - `numBytes`: Number of bytes allocated (undisposed) at this time. + * - `numTensors`: Number of unique tensors allocated. + * - `numDataBuffers`: Number of unique data buffers allocated + * (undisposed) at this time, which is ≤ the number of tensors + * (e.g. `a.reshape(newShape)` makes a new Tensor that shares the same + * data buffer with `a`). + * - `unreliable`: True if the memory usage is unreliable. See `reasons` when + * `unreliable` is true. + * - `reasons`: `string[]`, reasons why the memory is unreliable, present if + * `unreliable` is true. + * + * WebGL Properties: + * - `numBytesInGPU`: Number of bytes allocated (undisposed) in the GPU only at + * this time. + */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function memory() { + return ENGINE.memory(); + } + /** + * Executes the provided function `f()` and returns a promise that resolves + * with information about the function's memory use: + * - `newBytes`: tne number of new bytes allocated + * - `newTensors`: the number of new tensors created + * - `peakBytes`: the peak number of bytes allocated + * - `kernels`: an array of objects for each kernel involved that reports + * their input and output shapes, number of bytes used, and number of new + * tensors created. + * + * ```js + * const profile = await tf.profile(() => { + * const x = tf.tensor1d([1, 2, 3]); + * let x2 = x.square(); + * x2.dispose(); + * x2 = x.square(); + * x2.dispose(); + * return x; + * }); + * + * console.log(`newBytes: ${profile.newBytes}`); + * console.log(`newTensors: ${profile.newTensors}`); + * console.log(`byte usage over all kernels: ${profile.kernels.map(k => + * k.totalBytesSnapshot)}`); + * ``` + * + */ + /** @doc {heading: 'Performance', subheading: 'Profile'} */ + function profile(f) { + return ENGINE.profile(f); + } + /** + * Executes the provided function `fn` and after it is executed, cleans up all + * intermediate tensors allocated by `fn` except those returned by `fn`. + * `fn` must not return a Promise (async functions not allowed). The returned + * result can be a complex object. + * + * Using this method helps avoid memory leaks. In general, wrap calls to + * operations in `tf.tidy` for automatic memory cleanup. + * + * NOTE: Variables do *not* get cleaned up when inside a tidy(). If you want to + * dispose variables, please use `tf.disposeVariables` or call dispose() + * directly on variables. + * + * ```js + * // y = 2 ^ 2 + 1 + * const y = tf.tidy(() => { + * // a, b, and one will be cleaned up when the tidy ends. + * const one = tf.scalar(1); + * const a = tf.scalar(2); + * const b = a.square(); + * + * console.log('numTensors (in tidy): ' + tf.memory().numTensors); + * + * // The value returned inside the tidy function will return + * // through the tidy, in this case to the variable y. + * return b.add(one); + * }); + * + * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); + * y.print(); + * ``` + * + * @param nameOrFn The name of the closure, or the function to execute. + * If a name is provided, the 2nd argument should be the function. + * If debug mode is on, the timing and the memory usage of the function + * will be tracked and displayed on the console using the provided name. + * @param fn The function to execute. + */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function tidy(nameOrFn, fn) { + return ENGINE.tidy(nameOrFn, fn); + } + /** + * Disposes any `tf.Tensor`s found within the provided object. + * + * @param container an object that may be a `tf.Tensor` or may directly + * contain `tf.Tensor`s, such as a `Tensor[]` or `{key: Tensor, ...}`. If + * the object is not a `tf.Tensor` or does not contain `Tensors`, nothing + * happens. In general it is safe to pass any object here, except that + * `Promise`s are not supported. + */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function dispose(container) { + var tensors = getTensorsInContainer(container); + tensors.forEach(function (tensor) { return tensor.dispose(); }); + } + /** + * Keeps a `tf.Tensor` generated inside a `tf.tidy` from being disposed + * automatically. + * + * ```js + * let b; + * const y = tf.tidy(() => { + * const one = tf.scalar(1); + * const a = tf.scalar(2); + * + * // b will not be cleaned up by the tidy. a and one will be cleaned up + * // when the tidy ends. + * b = tf.keep(a.square()); + * + * console.log('numTensors (in tidy): ' + tf.memory().numTensors); + * + * // The value returned inside the tidy function will return + * // through the tidy, in this case to the variable y. + * return b.add(one); + * }); + * + * console.log('numTensors (outside tidy): ' + tf.memory().numTensors); + * console.log('y:'); + * y.print(); + * console.log('b:'); + * b.print(); + * ``` + * + * @param result The tensor to keep from being disposed. + */ + /** @doc {heading: 'Performance', subheading: 'Memory'} */ + function keep(result) { + return ENGINE.keep(result); + } + /** + * Executes `f()` and returns a promise that resolves with timing + * information. + * + * The result is an object with the following properties: + * + * - `wallMs`: Wall execution time. + * - `kernelMs`: Kernel execution time, ignoring data transfer. + * - On `WebGL` The following additional properties exist: + * - `uploadWaitMs`: CPU blocking time on texture uploads. + * - `downloadWaitMs`: CPU blocking time on texture downloads (readPixels). + * + * ```js + * const x = tf.randomNormal([20, 20]); + * const time = await tf.time(() => x.matMul(x)); + * + * console.log(`kernelMs: ${time.kernelMs}, wallTimeMs: ${time.wallMs}`); + * ``` + * + * @param f The function to execute and time. + */ + /** @doc {heading: 'Performance', subheading: 'Timing'} */ + function time(f) { + return ENGINE.time(f); + } + /** + * Sets the backend (cpu, webgl, etc) responsible for creating tensors and + * executing operations on those tensors. Returns a promise that resolves + * to a boolean if the backend initialization was successful. + * + * Note this disposes the current backend, if any, as well as any tensors + * associated with it. A new backend is initialized, even if it is of the + * same type as the previous one. + * + * @param backendName The name of the backend. Currently supports + * `'webgl'|'cpu'` in the browser, and `'tensorflow'` under node.js + * (requires tfjs-node). + */ + /** @doc {heading: 'Backends'} */ + function setBackend(backendName) { + return ENGINE.setBackend(backendName); + } + /** + * Returns a promise that resolves when the currently selected backend (or the + * highest priority one) has initialized. Await this promise when you are using + * a backend that has async initialization. + */ + /** @doc {heading: 'Backends'} */ + function ready() { + return ENGINE.ready(); + } + /** + * Returns the current backend name (cpu, webgl, etc). The backend is + * responsible for creating tensors and executing operations on those tensors. + */ + /** @doc {heading: 'Backends'} */ + function getBackend() { + return ENGINE.backendName; + } + /** + * Removes a backend and the registered factory. + */ + /** @doc {heading: 'Backends'} */ + function removeBackend(name) { + ENGINE.removeBackend(name); + } + /** + * Finds the backend registered under the provided name. Returns null if the + * name is not in the registry, or the registration hasn't finished yet. + */ + function findBackend(name) { + return ENGINE.findBackend(name); + } + /** + * Finds the backend factory registered under the provided name. Returns a + * function that produces a new backend when called. Returns null if the name + * is not in the registry. + */ + function findBackendFactory(name) { + return ENGINE.findBackendFactory(name); + } + /** + * Registers a global backend. The registration should happen when importing + * a module file (e.g. when importing `backend_webgl.ts`), and is used for + * modular builds (e.g. custom tfjs bundle with only webgl support). + * + * @param factory The backend factory function. When called, it should + * return a backend instance, or a promise of an instance. + * @param priority The priority of the backend (higher = more important). + * In case multiple backends are registered, the priority is used to find + * the best backend. Defaults to 1. + * @return False if there is already a registered backend under this name, true + * if not. + */ + /** @doc {heading: 'Backends'} */ + function registerBackend(name, factory, priority) { + if (priority === void 0) { priority = 1; } + return ENGINE.registerBackend(name, factory, priority); + } + /** + * Gets the current backend. If no backends have been initialized, this will + * attempt to initialize the best backend. Will throw an error if the highest + * priority backend has async initialization, in which case, you should call + * 'await tf.ready()' before running other code. + */ + /** @doc {heading: 'Backends'} */ + function backend() { + return ENGINE.backend; + } + /** + * Sets the global platform. + * + * @param platformName The name of this platform. + * @param platform A platform implementation. + */ + function setPlatform(platformName, platform) { + env().setPlatform(platformName, platform); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function warn() { + var msg = []; + for (var _i = 0; _i < arguments.length; _i++) { + msg[_i] = arguments[_i]; + } + if (!env().getBool('IS_TEST')) { + console.warn.apply(console, msg); + } + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function inferShape(val, dtype) { + var firstElem = val; + if (isTypedArray(val)) { + return dtype === 'string' ? [] : [val.length]; + } + if (!Array.isArray(val)) { + return []; // Scalar. + } + var shape = []; + while (Array.isArray(firstElem) || + isTypedArray(firstElem) && dtype !== 'string') { + shape.push(firstElem.length); + firstElem = firstElem[0]; + } + if (Array.isArray(val) && + env().getBool('TENSORLIKE_CHECK_SHAPE_CONSISTENCY')) { + deepAssertShapeConsistency(val, shape, []); + } + return shape; + } + function deepAssertShapeConsistency(val, shape, indices) { + indices = indices || []; + if (!(Array.isArray(val)) && !isTypedArray(val)) { + assert(shape.length === 0, function () { return "Element arr[" + indices.join('][') + "] is a primitive, " + + ("but should be an array/TypedArray of " + shape[0] + " elements"); }); + return; + } + assert(shape.length > 0, function () { return "Element arr[" + indices.join('][') + "] should be a primitive, " + + ("but is an array of " + val.length + " elements"); }); + assert(val.length === shape[0], function () { return "Element arr[" + indices.join('][') + "] should have " + shape[0] + " " + + ("elements, but has " + val.length + " elements"); }); + var subShape = shape.slice(1); + for (var i = 0; i < val.length; ++i) { + deepAssertShapeConsistency(val[i], subShape, indices.concat(i)); + } + } + function assertDtype(expectedDtype, actualDType, argName, functionName) { + if (expectedDtype == null) { + return; + } + if (expectedDtype !== 'numeric' && expectedDtype !== actualDType || + expectedDtype === 'numeric' && actualDType === 'string') { + throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must " + + ("be " + expectedDtype + " tensor, but got " + actualDType + " tensor")); + } + } + function convertToTensor(x, argName, functionName, parseAsDtype) { + if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } + if (x instanceof Tensor) { + assertDtype(parseAsDtype, x.dtype, argName, functionName); + return x; + } + var inferredDtype = inferDtype(x); + // If the user expects a bool/int/float, use that info to update the + // inferredDtype when it is not a string. + if (inferredDtype !== 'string' && + ['bool', 'int32', 'float32'].indexOf(parseAsDtype) >= 0) { + inferredDtype = parseAsDtype; + } + assertDtype(parseAsDtype, inferredDtype, argName, functionName); + if ((x == null) || + (!isTypedArray(x) && !Array.isArray(x) && typeof x !== 'number' && + typeof x !== 'boolean' && typeof x !== 'string')) { + var type = x == null ? 'null' : x.constructor.name; + throw new Error("Argument '" + argName + "' passed to '" + functionName + "' must be a " + + ("Tensor or TensorLike, but got '" + type + "'")); + } + var inferredShape = inferShape(x, inferredDtype); + if (!isTypedArray(x) && !Array.isArray(x)) { + x = [x]; + } + var skipTypedArray = true; + var values = inferredDtype !== 'string' ? + toTypedArray(x, inferredDtype, env().getBool('DEBUG')) : + flatten(x, [], skipTypedArray); + return ENGINE.makeTensor(values, inferredShape, inferredDtype); + } + function convertToTensorArray(arg, argName, functionName, parseAsDtype) { + if (parseAsDtype === void 0) { parseAsDtype = 'numeric'; } + if (!Array.isArray(arg)) { + throw new Error("Argument " + argName + " passed to " + functionName + " must be a " + + '`Tensor[]` or `TensorLike[]`'); + } + var tensors = arg; + return tensors.map(function (t, i) { return convertToTensor(t, argName + "[" + i + "]", functionName); }, parseAsDtype); + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Returns true if the axis specifies the inner most dimensions of the + * array. + */ + function axesAreInnerMostDims(axes, rank) { + for (var i = 0; i < axes.length; ++i) { + if (axes[axes.length - i - 1] !== rank - 1 - i) { + return false; + } + } + return true; + } + function combineLocations(outputLoc, reduceLoc, axes) { + var rank = outputLoc.length + reduceLoc.length; + var loc = []; + var outIdx = 0; + var reduceIdx = 0; + for (var dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + loc.push(outputLoc[outIdx++]); + } + else { + loc.push(reduceLoc[reduceIdx++]); + } + } + return loc; + } + function computeOutAndReduceShapes(aShape, axes) { + var outShape = []; + var rank = aShape.length; + for (var dim = 0; dim < rank; dim++) { + if (axes.indexOf(dim) === -1) { + outShape.push(aShape[dim]); + } + } + var reduceShape = axes.map(function (dim) { return aShape[dim]; }); + return [outShape, reduceShape]; + } + function expandShapeToKeepDim(shape, axes) { + var reduceSubShape = axes.map(function (x) { return 1; }); + return combineLocations(shape, reduceSubShape, axes); + } + function assertAxesAreInnerMostDims(msg, axes, rank) { + assert(axesAreInnerMostDims(axes, rank), function () { return msg + " supports only inner-most axes for now. " + + ("Got axes " + axes + " and rank-" + rank + " input."); }); + } + /** + * Returns the axes permutation to be used with `tf.transpose`, if such + * permutation is necessary. Otherwise it returns null. This method is used by + * operations that operate only on inner-most axes. + */ + function getAxesPermutation(axes, rank) { + if (axesAreInnerMostDims(axes, rank)) { + return null; + } + var result = []; + for (var i = 0; i < rank; ++i) { + if (axes.indexOf(i) === -1) { + result.push(i); + } + } + axes.forEach(function (axis) { return result.push(axis); }); + return result; + } + /** Returns the axes permutation that undoes the original permutation. */ + function getUndoAxesPermutation(axes) { + return axes.map(function (axis, i) { return [i, axis]; }) + .sort(function (a, b) { return a[1] - b[1]; }) + .map(function (x) { return x[0]; }); + } + function getInnerMostAxes(numAxes, rank) { + var res = []; + for (var i = rank - numAxes; i < rank; ++i) { + res.push(i); + } + return res; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function assertParamsConsistent(shapes, axis) { + var rank = shapes[0].length; + shapes.forEach(function (shape, i) { + assert(shape.length === rank, function () { + return "Error in concat" + rank + "D: rank of tensors[" + i + "] must be the same " + + ("as the rank of the rest (" + rank + ")"); + }); + }); + assert(axis >= 0 && axis < rank, function () { return "Error in concat" + rank + "D: axis must be between 0 and " + (rank - 1) + "."; }); + var firstShape = shapes[0]; + shapes.forEach(function (shape, i) { + for (var r = 0; r < rank; r++) { + assert((r === axis) || (shape[r] === firstShape[r]), function () { return "Error in concat" + rank + "D: Shape of tensors[" + i + "] (" + shape + ") " + + ("does not match the shape of the rest (" + firstShape + ") ") + + ("along the non-concatenated axis " + i + "."); }); + } + }); + } + function computeOutShape(shapes, axis) { + var outputShape = shapes[0].slice(); + for (var i = 1; i < shapes.length; i++) { + outputShape[axis] += shapes[i][axis]; + } + return outputShape; + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Used for wrapping functions that perform math operations on + * Tensors. The function will be wrapped in a named scope that cleans all + * memory usage after the function is done. + */ + function op(f) { + var keys = Object.keys(f); + if (keys.length !== 1) { + throw new Error("Please provide an object with a single key " + + "(operation name) mapping to a function. Got an object with " + + (keys.length + " keys.")); + } + var opName = keys[0]; + var fn = f[opName]; + // Strip the underscore from the end of the function name. + if (opName.endsWith('_')) { + opName = opName.substring(0, opName.length - 1); + } + // tslint:disable-next-line:no-any + var f2 = function () { + var args = []; + for (var _i = 0; _i < arguments.length; _i++) { + args[_i] = arguments[_i]; + } + ENGINE.startScope(opName); + try { + var result = fn.apply(void 0, args); + if (result instanceof Promise) { + console.error('Cannot return a Promise inside of tidy.'); + } + ENGINE.endScope(result); + return result; + } + catch (ex) { + ENGINE.endScope(null); + throw ex; + } + }; + Object.defineProperty(f2, 'name', { value: opName, configurable: true }); + // tslint:disable-next-line:no-any + return f2; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Converts two real numbers to a complex number. + * + * Given a tensor `real` representing the real part of a complex number, and a + * tensor `imag` representing the imaginary part of a complex number, this + * operation returns complex numbers elementwise of the form [r0, i0, r1, i1], + * where r represents the real part and i represents the imag part. + * + * The input tensors real and imag must have the same shape. + * + * ```js + * const real = tf.tensor1d([2.25, 3.25]); + * const imag = tf.tensor1d([4.75, 5.75]); + * const complex = tf.complex(real, imag); + * + * complex.print(); + * ``` + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function complex_(real, imag) { + var $real = convertToTensor(real, 'real', 'complex'); + var $imag = convertToTensor(imag, 'imag', 'complex'); + assertShapesMatch($real.shape, $imag.shape, "real and imag shapes, " + $real.shape + " and " + $imag.shape + ", " + + "must match in call to tf.complex()."); + return ENGINE.runKernelFunc(function (backend) { return backend.complex($real, $imag); }, { $real: $real, $imag: $imag }); + } + /** + * Returns the real part of a complex (or real) tensor. + * + * Given a tensor input, this operation returns a tensor of type float that is + * the real part of each element in input considered as a complex number. + * + * If the input is real, it simply makes a clone. + * + * ```js + * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); + * tf.real(x).print(); + * ``` + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function real_(input) { + var $input = convertToTensor(input, 'input', 'real'); + return ENGINE.runKernelFunc(function (backend) { return backend.real($input); }, { $input: $input }); + } + /** + * Returns the imaginary part of a complex (or real) tensor. + * + * Given a tensor input, this operation returns a tensor of type float that is + * the imaginary part of each element in input considered as a complex number. + * If input is real, a tensor of all zeros is returned. + * + * ```js + * const x = tf.complex([-2.25, 3.25], [4.75, 5.75]); + * tf.imag(x).print(); + * ``` + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function imag_(input) { + var $input = convertToTensor(input, 'input', 'imag'); + return ENGINE.runKernelFunc(function (backend) { return backend.imag($input); }, { $input: $input }); + } + var complex = op({ complex_: complex_ }); + var real = op({ real_: real_ }); + var imag = op({ imag_: imag_ }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Creates a `tf.Tensor` with the provided values, shape and dtype. + * + * ```js + * // Pass an array of values to create a vector. + * tf.tensor([1, 2, 3, 4]).print(); + * ``` + * + * ```js + * // Pass a nested array of values to make a matrix or a higher + * // dimensional tensor. + * tf.tensor([[1, 2], [3, 4]]).print(); + * ``` + * + * ```js + * // Pass a flat array and specify a shape yourself. + * tf.tensor([1, 2, 3, 4], [2, 2]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. If the values are strings, + * they will be encoded as utf-8 and kept as `Uint8Array[]`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor(values, shape, dtype) { + var inferredShape = inferShape(values, dtype); + return makeTensor(values, shape, inferredShape, dtype); + } + /** This is shared code across all tensor creation methods. */ + function makeTensor(values, shape, inferredShape, dtype) { + if (dtype == null) { + dtype = inferDtype(values); + } + if (dtype === 'complex64') { + throw new Error("Cannot construct a complex64 tensor directly. " + + "Please use tf.complex(real, imag)."); + } + if (!isTypedArray(values) && !Array.isArray(values) && + typeof values !== 'number' && typeof values !== 'boolean' && + typeof values !== 'string') { + throw new Error('values passed to tensor(values) must be a number/boolean/string or ' + + 'an array of numbers/booleans/strings, or a TypedArray'); + } + if (shape != null) { + assertNonNegativeIntegerDimensions(shape); + var providedSize_1 = sizeFromShape(shape); + var inferredSize_1 = sizeFromShape(inferredShape); + assert(providedSize_1 === inferredSize_1, function () { + return "Based on the provided shape, [" + shape + "], the tensor should have " + + (providedSize_1 + " values but has " + inferredSize_1); + }); + for (var i = 0; i < inferredShape.length; ++i) { + var inferred = inferredShape[i]; + var flatDimsDontMatch = i === inferredShape.length - 1 ? + inferred !== sizeFromShape(shape.slice(i)) : + true; + assert(inferredShape[i] === shape[i] || !flatDimsDontMatch, function () { return "Error creating a new Tensor. Inferred shape " + + ("(" + inferredShape + ") does not match the provided ") + + ("shape (" + shape + "). "); }); + } + } + if (!isTypedArray(values) && !Array.isArray(values)) { + values = [values]; + } + shape = shape || inferredShape; + values = dtype !== 'string' ? + toTypedArray(values, dtype, env().getBool('DEBUG')) : + flatten(values, [], true); + return ENGINE.makeTensor(values, shape, dtype); + } + /** + * Creates rank-0 `tf.Tensor` (scalar) with the provided value and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.scalar` as it makes the code more readable. + * + * ```js + * tf.scalar(3.14).print(); + * ``` + * + * @param value The value of the scalar. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function scalar(value, dtype) { + if (((isTypedArray(value) && dtype !== 'string') || Array.isArray(value)) && + dtype !== 'complex64') { + throw new Error('Error creating a new Scalar: value must be a primitive ' + + '(number|boolean|string)'); + } + if (dtype === 'string' && isTypedArray(value) && + !(value instanceof Uint8Array)) { + throw new Error('When making a scalar from encoded string, ' + + 'the value must be `Uint8Array`.'); + } + var shape = []; + var inferredShape = []; + return makeTensor(value, shape, inferredShape, dtype); + } + /** + * Creates rank-1 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor1d` as it makes the code more readable. + * + * ```js + * tf.tensor1d([1, 2, 3]).print(); + * ``` + * + * @param values The values of the tensor. Can be array of numbers, + * or a `TypedArray`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor1d(values, dtype) { + assertNonNull(values); + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 1) { + throw new Error('tensor1d() requires values to be a flat/TypedArray'); + } + var shape = null; + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-2 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor2d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor2d([[1, 2], [3, 4]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. If not provided, it is inferred from + * `values`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor2d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 2) { + throw new Error('tensor2d() requires shape to have two numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 2 && inferredShape.length !== 1) { + throw new Error('tensor2d() requires values to be number[][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor2d() requires shape to be provided when `values` ' + + 'are a flat/TypedArray'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-3 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor3d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor3d([[[1], [2]], [[3], [4]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor3d([1, 2, 3, 4], [2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. If not provided, it is inferred from + * `values`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor3d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 3) { + throw new Error('tensor3d() requires shape to have three numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 3 && inferredShape.length !== 1) { + throw new Error('tensor3d() requires values to be number[][][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor3d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-4 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor4d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor4d([[[[1], [2]], [[3], [4]]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor4d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 4) { + throw new Error('tensor4d() requires shape to have four numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 4 && inferredShape.length !== 1) { + throw new Error('tensor4d() requires values to be number[][][][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor4d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-5 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor5d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor5d([[[[[1], [2]], [[3], [4]]]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor5d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 5) { + throw new Error('tensor5d() requires shape to have five numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 5 && inferredShape.length !== 1) { + throw new Error('tensor5d() requires values to be ' + + 'number[][][][][] or flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor5d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates rank-6 `tf.Tensor` with the provided values, shape and dtype. + * + * The same functionality can be achieved with `tf.tensor`, but in general + * we recommend using `tf.tensor6d` as it makes the code more readable. + * + * ```js + * // Pass a nested array. + * tf.tensor6d([[[[[[1],[2]],[[3],[4]]],[[[5],[6]],[[7],[8]]]]]]).print(); + * ``` + * ```js + * // Pass a flat array and specify a shape. + * tf.tensor6d([1, 2, 3, 4, 5, 6, 7, 8], [1, 1, 2, 2, 2, 1]).print(); + * ``` + * + * @param values The values of the tensor. Can be nested array of numbers, + * or a flat array, or a `TypedArray`. + * @param shape The shape of the tensor. Optional. If not provided, + * it is inferred from `values`. + * @param dtype The data type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function tensor6d(values, shape, dtype) { + assertNonNull(values); + if (shape != null && shape.length !== 6) { + throw new Error('tensor6d() requires shape to have six numbers'); + } + var inferredShape = inferShape(values, dtype); + if (inferredShape.length !== 6 && inferredShape.length !== 1) { + throw new Error('tensor6d() requires values to be number[][][][][][] or ' + + 'flat/TypedArray'); + } + if (inferredShape.length === 1 && shape == null) { + throw new Error('tensor6d() requires shape to be provided when `values` ' + + 'are a flat array'); + } + shape = shape || + inferredShape; + return makeTensor(values, shape, inferredShape, dtype); + } + /** + * Creates a new variable with the provided initial value. + * ```js + * const x = tf.variable(tf.tensor([1, 2, 3])); + * x.assign(tf.tensor([4, 5, 6])); + * + * x.print(); + * ``` + * + * @param initialValue Initial value for the tensor. + * @param trainable If true, optimizers are allowed to update it. + * @param name Name of the variable. Defaults to a unique id. + * @param dtype If set, initialValue will be converted to the given type. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function variable(initialValue, trainable, name, dtype) { + if (trainable === void 0) { trainable = true; } + return ENGINE.makeVariable(initialValue, trainable, name, dtype); + } + /** + * Creates a `tf.Tensor` with all elements set to 1. + * + * ```js + * tf.ones([2, 2]).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param dtype The type of an element in the resulting tensor. Defaults to + * 'float'. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function ones$1(shape, dtype) { + if (dtype === void 0) { dtype = 'float32'; } + if (dtype === 'complex64') { + var real_1 = ones$1(shape, 'float32'); + var imag_1 = zeros(shape, 'float32'); + return complex(real_1, imag_1); + } + var values = makeOnesTypedArray(sizeFromShape(shape), dtype); + return ENGINE.makeTensor(values, shape, dtype); + } + /** + * Creates a `tf.Tensor` with all elements set to 0. + * + * ```js + * tf.zeros([2, 2]).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param dtype The type of an element in the resulting tensor. Can + * be 'float32', 'int32' or 'bool'. Defaults to 'float'. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function zeros(shape, dtype) { + if (dtype === void 0) { dtype = 'float32'; } + if (dtype === 'complex64') { + var real_2 = zeros(shape, 'float32'); + var imag_2 = zeros(shape, 'float32'); + return complex(real_2, imag_2); + } + var values = makeZerosTypedArray(sizeFromShape(shape), dtype); + return ENGINE.makeTensor(values, shape, dtype); + } + /** + * Creates a `tf.Tensor` filled with a scalar value. + * + * ```js + * tf.fill([2, 2], 4).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param value The scalar value to fill the tensor with. + * @param dtype The type of an element in the resulting tensor. Defaults to + * 'float'. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function fill(shape, value, dtype) { + return ENGINE.runKernelFunc(function (backend) { return backend.fill(shape, value, dtype); }, {}); + } + /** + * Creates a `tf.Tensor` with all elements set to 1 with the same shape as the + * given tensor. + * + * ```js + * const x = tf.tensor([1, 2]); + * tf.onesLike(x).print(); + * ``` + * @param x A tensor. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function onesLike_(x) { + var $x = convertToTensor(x, 'x', 'onesLike'); + if ($x.dtype === 'complex64') { + var r = onesLike(real($x)); + var i = zerosLike(imag($x)); + return complex(r, i); + } + var der = function (dy, saved) { return ({ $x: function () { return zerosLike(dy); } }); }; + return ENGINE.runKernelFunc(function (backend) { return backend.onesLike($x); }, { $x: $x }, der); + } + /** + * Creates a `tf.Tensor` with all elements set to 0 with the same shape as the + * given tensor. + * + * ```js + * const x = tf.tensor([1, 2]); + * tf.zerosLike(x).print(); + * ``` + * + * @param x The tensor of required shape. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function zerosLike_(x) { + var $x = convertToTensor(x, 'x', 'zerosLike'); + var der = function (dy, saved) { return ({ $x: function () { return zerosLike(dy); } }); }; + return ENGINE.runKernelFunc(function (backend) { return backend.zerosLike($x); }, { $x: $x }, der); + } + /** + * Return an evenly spaced sequence of numbers over the given interval. + * + * ```js + * tf.linspace(0, 9, 10).print(); + * ``` + * @param start The start value of the sequence. + * @param stop The end value of the sequence. + * @param num The number of values to generate. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function linspace(start, stop, num) { + if (num <= 0) { + throw new Error('The number of values should be positive.'); + } + return ENGINE.runKernelFunc(function (backend) { return backend.linspace(start, stop, num); }, {}); + } + /** + * Creates a new `tf.Tensor1D` filled with the numbers in the range provided. + * + * The tensor is a is half-open interval meaning it includes start, but + * excludes stop. Decrementing ranges and negative step values are also + * supported. + * + * ```js + * tf.range(0, 9, 2).print(); + * ``` + * + * @param start An integer start value + * @param stop An integer stop value + * @param step An integer increment (will default to 1 or -1) + * @param dtype The data type of the output tensor. Defaults to 'float32'. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function range(start, stop, step, dtype) { + if (step === void 0) { step = 1; } + if (dtype === void 0) { dtype = 'float32'; } + if (step === 0) { + throw new Error('Cannot have a step of zero'); + } + var sameStartStop = start === stop; + var increasingRangeNegativeStep = start < stop && step < 0; + var decreasingRangePositiveStep = stop < start && step > 1; + if (sameStartStop || increasingRangeNegativeStep || + decreasingRangePositiveStep) { + return zeros([0], dtype); + } + var numElements = Math.abs(Math.ceil((stop - start) / step)); + var values = makeZerosTypedArray(numElements, dtype); + if (stop < start && step === 1) { + // Auto adjust the step's sign if it hasn't been set + // (or was set to 1) + step = -1; + } + values[0] = start; + for (var i = 1; i < values.length; i++) { + values[i] = values[i - 1] + step; + } + return tensor1d(values, dtype); + } + var onesLike = op({ onesLike_: onesLike_ }); + var zerosLike = op({ zerosLike_: zerosLike_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Concatenates a list of`tf.Tensor1D`s along an axis. See `concat` for details. + * + * For example, if: + * A: shape(3) = |r1, g1, b1| + * B: shape(2) = |r2, g2| + * C = tf.concat1d([A, B]) == |r1, g1, b1, r2, g2| + * + * @param tensors A list of`tf.Tensor`s to concatenate. + * @return The concatenated array. + */ + function concat1d_(tensors) { + return concat(tensors, 0 /* axis */); + } + /** + * Concatenates a list of`tf.Tensor2D`s along an axis. See `concat` for details. + * + * For example, if: + * A: shape(2, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * + * B: shape(2, 3) = | r3, g3, b3 | + * | r4, g4, b4 | + * + * C = tf.concat2d([A, B], axis) + * + * if axis = 0: + * C: shape(4, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * | r3, g3, b3 | + * | r4, g4, b4 | + * + * if axis = 1: + * C = shape(2, 6) = | r1, g1, b1, r3, g3, b3 | + * | r2, g2, b2, r4, g4, b4 | + * + * + * @param tensors A list of `tf.Tensor`s to concatenate. + * @param axis The axis to concatenate along. + * @return The concatenated array. + */ + function concat2d_(tensors, axis) { + return concat(tensors, axis); + } + /** + * Concatenates a list of `tf.Tensor3D`s along an axis. + * See `concat` for details. + * + * For example, if: + * A: shape(2, 1, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * + * B: shape(2, 1, 3) = | r3, g3, b3 | + * | r4, g4, b4 | + * + * C = tf.concat3d([A, B], axis) + * + * if axis = 0: + * C: shape(4, 1, 3) = | r1, g1, b1 | + * | r2, g2, b2 | + * | r3, g3, b3 | + * | r4, g4, b4 | + * + * if axis = 1: + * C: shape(2, 2, 3) = | r1, g1, b1, r3, g3, b3 | + * | r2, g2, b2, r4, g4, b4 | + * + * if axis = 2: + * C = shape(2, 1, 6) = | r1, g1, b1, r3, g3, b3 | + * | r2, g2, b2, r4, g4, b4 | + * + * @param tensors A list of`tf.Tensor`s to concatenate. + * @param axis The axis to concate along. + * @return The concatenated array. + */ + function concat3d_(tensors, axis) { + return concat(tensors, axis); + } + /** + * Concatenates a list of `tf.Tensor4D`s along an axis. + * See `concat` for details. + * + * @param tensors A list of `tf.Tensor`s to concatenate. + * @param axis The axis to concate along. + * @return The concatenated array. + */ + function concat4d_(tensors, axis) { + return concat(tensors, axis); + } + /** + * Concatenates a list of `tf.Tensor`s along a given axis. + * + * The tensors ranks and types must match, and their sizes must match in all + * dimensions except `axis`. + * + * Also available are stricter rank-specific methods that assert that + * `tensors` are of the given rank: + * - `tf.concat1d` + * - `tf.concat2d` + * - `tf.concat3d` + * - `tf.concat4d` + * + * Except `tf.concat1d` (which does not have axis param), all methods have + * same signature as this method. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * a.concat(b).print(); // or a.concat(b) + * ``` + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * const c = tf.tensor1d([5, 6]); + * tf.concat([a, b, c]).print(); + * ``` + * + * ```js + * const a = tf.tensor2d([[1, 2], [10, 20]]); + * const b = tf.tensor2d([[3, 4], [30, 40]]); + * const axis = 1; + * tf.concat([a, b], axis).print(); + * ``` + * @param tensors A list of tensors to concatenate. + * @param axis The axis to concate along. Defaults to 0 (the first dim). + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function concat_(tensors, axis) { + if (axis === void 0) { axis = 0; } + assert(tensors.length >= 1, function () { return 'Pass at least one tensor to concat'; }); + var $tensors = convertToTensorArray(tensors, 'tensors', 'concat'); + if ($tensors[0].dtype === 'complex64') { + $tensors.forEach(function (tensor) { + if (tensor.dtype !== 'complex64') { + throw new Error("Cannot concatenate complex64 tensors with a tensor\n with dtype " + tensor.dtype + ". "); + } + }); + } + axis = parseAxisParam(axis, $tensors[0].shape)[0]; + var outShape = computeOutShape($tensors.map(function (t) { return t.shape; }), axis); + if (sizeFromShape(outShape) === 0) { + return tensor([], outShape); + } + // Keep only non-empty tensors (ignore tensors with 0 in their shape). + $tensors = $tensors.filter(function (t) { return t.size > 0; }); + if ($tensors.length === 1) { + return $tensors[0]; + } + var shapes = $tensors.map(function (t) { return t.shape; }); + assertParamsConsistent(shapes, axis); + var der = function (dy) { + var sizeSplits = shapes.map(function (s) { return s[axis]; }); + var derTensors = split(dy, sizeSplits, axis); + return derTensors.map(function (t) { return function () { return t; }; }); + }; + var inputs = $tensors; + var attr = { axis: axis }; + return ENGINE.runKernelFunc(function (backend) { return backend.concat($tensors, axis); }, inputs, der, 'Concat', attr); + } + /** + * Splits a `tf.Tensor` into sub tensors. + * + * If `numOrSizeSplits` is a number, splits `x` along dimension `axis` + * into `numOrSizeSplits` smaller tensors. + * Requires that `numOrSizeSplits` evenly divides `x.shape[axis]`. + * + * If `numOrSizeSplits` is a number array, splits `x` into + * `numOrSizeSplits.length` pieces. The shape of the `i`-th piece has the + * same size as `x` except along dimension `axis` where the size is + * `numOrSizeSplits[i]`. + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4, 5, 6, 7, 8], [2, 4]); + * const [a, b] = tf.split(x, 2, 1); + * a.print(); + * b.print(); + * + * const [c, d, e] = tf.split(x, [1, 2, 1], 1); + * c.print(); + * d.print(); + * e.print(); + * ``` + * + * @param x The input tensor to split. + * @param numOrSizeSplits Either an integer indicating the number of + * splits along the axis or an array of integers containing the sizes of + * each output tensor along the axis. If a number then it must evenly divide + * `x.shape[axis]`; otherwise the sum of sizes must match `x.shape[axis]`. + * @param axis The dimension along which to split. Defaults to 0 (the first + * dim). + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function split_(x, numOrSizeSplits, axis) { + if (axis === void 0) { axis = 0; } + var $x = convertToTensor(x, 'x', 'split'); + axis = parseAxisParam(axis, $x.shape)[0]; + var splitSizes; + if (typeof (numOrSizeSplits) === 'number') { + assert($x.shape[axis] % numOrSizeSplits === 0, function () { return 'Number of splits must evenly divide the axis.'; }); + splitSizes = + new Array(numOrSizeSplits).fill($x.shape[axis] / numOrSizeSplits); + } + else { + assert($x.shape[axis] === numOrSizeSplits.reduce(function (a, b) { return a + b; }), function () { return 'The sum of sizes must match the size of the axis dimension.'; }); + splitSizes = numOrSizeSplits; + } + var der = function (dy) { return ({ $x: function () { return concat(dy, axis); } }); }; + return ENGINE.runKernelFunc(function (backend) { return backend.split($x, splitSizes, axis); }, { $x: $x }, der); + } + var concat = op({ concat_: concat_ }); + var concat1d = op({ concat1d_: concat1d_ }); + var concat2d = op({ concat2d_: concat2d_ }); + var concat3d = op({ concat3d_: concat3d_ }); + var concat4d = op({ concat4d_: concat4d_ }); + var split = op({ split_: split_ }); + + var commonjsGlobal = typeof globalThis !== 'undefined' ? globalThis : typeof window !== 'undefined' ? window : typeof global !== 'undefined' ? global : typeof self !== 'undefined' ? self : {}; + + function createCommonjsModule(fn, module) { + return module = { exports: {} }, fn(module, module.exports), module.exports; + } + + var alea = createCommonjsModule(function (module) { + // A port of an algorithm by Johannes Baagøe , 2010 + // http://baagoe.com/en/RandomMusings/javascript/ + // https://github.com/nquinlan/better-random-numbers-for-javascript-mirror + // Original work is under MIT license - + + // Copyright (C) 2010 by Johannes Baagøe + // + // Permission is hereby granted, free of charge, to any person obtaining a copy + // of this software and associated documentation files (the "Software"), to deal + // in the Software without restriction, including without limitation the rights + // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + // copies of the Software, and to permit persons to whom the Software is + // furnished to do so, subject to the following conditions: + // + // The above copyright notice and this permission notice shall be included in + // all copies or substantial portions of the Software. + // + // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + // THE SOFTWARE. + + + + (function(global, module, define) { + + function Alea(seed) { + var me = this, mash = Mash(); + + me.next = function() { + var t = 2091639 * me.s0 + me.c * 2.3283064365386963e-10; // 2^-32 + me.s0 = me.s1; + me.s1 = me.s2; + return me.s2 = t - (me.c = t | 0); + }; + + // Apply the seeding algorithm from Baagoe. + me.c = 1; + me.s0 = mash(' '); + me.s1 = mash(' '); + me.s2 = mash(' '); + me.s0 -= mash(seed); + if (me.s0 < 0) { me.s0 += 1; } + me.s1 -= mash(seed); + if (me.s1 < 0) { me.s1 += 1; } + me.s2 -= mash(seed); + if (me.s2 < 0) { me.s2 += 1; } + mash = null; + } + + function copy(f, t) { + t.c = f.c; + t.s0 = f.s0; + t.s1 = f.s1; + t.s2 = f.s2; + return t; + } + + function impl(seed, opts) { + var xg = new Alea(seed), + state = opts && opts.state, + prng = xg.next; + prng.int32 = function() { return (xg.next() * 0x100000000) | 0; }; + prng.double = function() { + return prng() + (prng() * 0x200000 | 0) * 1.1102230246251565e-16; // 2^-53 + }; + prng.quick = prng; + if (state) { + if (typeof(state) == 'object') copy(state, xg); + prng.state = function() { return copy(xg, {}); }; + } + return prng; + } + + function Mash() { + var n = 0xefc8249d; + + var mash = function(data) { + data = data.toString(); + for (var i = 0; i < data.length; i++) { + n += data.charCodeAt(i); + var h = 0.02519603282416938 * n; + n = h >>> 0; + h -= n; + h *= n; + n = h >>> 0; + h -= n; + n += h * 0x100000000; // 2^32 + } + return (n >>> 0) * 2.3283064365386963e-10; // 2^-32 + }; + + return mash; + } + + + if (module && module.exports) { + module.exports = impl; + } else if (define && define.amd) { + define(function() { return impl; }); + } else { + this.alea = impl; + } + + })( + commonjsGlobal, + module, // present in node.js + (typeof undefined) == 'function' // present with an AMD loader + ); + }); + + var xor128 = createCommonjsModule(function (module) { + // A Javascript implementaion of the "xor128" prng algorithm by + // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper + + (function(global, module, define) { + + function XorGen(seed) { + var me = this, strseed = ''; + + me.x = 0; + me.y = 0; + me.z = 0; + me.w = 0; + + // Set up generator function. + me.next = function() { + var t = me.x ^ (me.x << 11); + me.x = me.y; + me.y = me.z; + me.z = me.w; + return me.w ^= (me.w >>> 19) ^ t ^ (t >>> 8); + }; + + if (seed === (seed | 0)) { + // Integer seed. + me.x = seed; + } else { + // String seed. + strseed += seed; + } + + // Mix in string seed, then discard an initial batch of 64 values. + for (var k = 0; k < strseed.length + 64; k++) { + me.x ^= strseed.charCodeAt(k) | 0; + me.next(); + } + } + + function copy(f, t) { + t.x = f.x; + t.y = f.y; + t.z = f.z; + t.w = f.w; + return t; + } + + function impl(seed, opts) { + var xg = new XorGen(seed), + state = opts && opts.state, + prng = function() { return (xg.next() >>> 0) / 0x100000000; }; + prng.double = function() { + do { + var top = xg.next() >>> 11, + bot = (xg.next() >>> 0) / 0x100000000, + result = (top + bot) / (1 << 21); + } while (result === 0); + return result; + }; + prng.int32 = xg.next; + prng.quick = prng; + if (state) { + if (typeof(state) == 'object') copy(state, xg); + prng.state = function() { return copy(xg, {}); }; + } + return prng; + } + + if (module && module.exports) { + module.exports = impl; + } else if (define && define.amd) { + define(function() { return impl; }); + } else { + this.xor128 = impl; + } + + })( + commonjsGlobal, + module, // present in node.js + (typeof undefined) == 'function' // present with an AMD loader + ); + }); + + var xorwow = createCommonjsModule(function (module) { + // A Javascript implementaion of the "xorwow" prng algorithm by + // George Marsaglia. See http://www.jstatsoft.org/v08/i14/paper + + (function(global, module, define) { + + function XorGen(seed) { + var me = this, strseed = ''; + + // Set up generator function. + me.next = function() { + var t = (me.x ^ (me.x >>> 2)); + me.x = me.y; me.y = me.z; me.z = me.w; me.w = me.v; + return (me.d = (me.d + 362437 | 0)) + + (me.v = (me.v ^ (me.v << 4)) ^ (t ^ (t << 1))) | 0; + }; + + me.x = 0; + me.y = 0; + me.z = 0; + me.w = 0; + me.v = 0; + + if (seed === (seed | 0)) { + // Integer seed. + me.x = seed; + } else { + // String seed. + strseed += seed; + } + + // Mix in string seed, then discard an initial batch of 64 values. + for (var k = 0; k < strseed.length + 64; k++) { + me.x ^= strseed.charCodeAt(k) | 0; + if (k == strseed.length) { + me.d = me.x << 10 ^ me.x >>> 4; + } + me.next(); + } + } + + function copy(f, t) { + t.x = f.x; + t.y = f.y; + t.z = f.z; + t.w = f.w; + t.v = f.v; + t.d = f.d; + return t; + } + + function impl(seed, opts) { + var xg = new XorGen(seed), + state = opts && opts.state, + prng = function() { return (xg.next() >>> 0) / 0x100000000; }; + prng.double = function() { + do { + var top = xg.next() >>> 11, + bot = (xg.next() >>> 0) / 0x100000000, + result = (top + bot) / (1 << 21); + } while (result === 0); + return result; + }; + prng.int32 = xg.next; + prng.quick = prng; + if (state) { + if (typeof(state) == 'object') copy(state, xg); + prng.state = function() { return copy(xg, {}); }; + } + return prng; + } + + if (module && module.exports) { + module.exports = impl; + } else if (define && define.amd) { + define(function() { return impl; }); + } else { + this.xorwow = impl; + } + + })( + commonjsGlobal, + module, // present in node.js + (typeof undefined) == 'function' // present with an AMD loader + ); + }); + + var xorshift7 = createCommonjsModule(function (module) { + // A Javascript implementaion of the "xorshift7" algorithm by + // François Panneton and Pierre L'ecuyer: + // "On the Xorgshift Random Number Generators" + // http://saluc.engr.uconn.edu/refs/crypto/rng/panneton05onthexorshift.pdf + + (function(global, module, define) { + + function XorGen(seed) { + var me = this; + + // Set up generator function. + me.next = function() { + // Update xor generator. + var X = me.x, i = me.i, t, v; + t = X[i]; t ^= (t >>> 7); v = t ^ (t << 24); + t = X[(i + 1) & 7]; v ^= t ^ (t >>> 10); + t = X[(i + 3) & 7]; v ^= t ^ (t >>> 3); + t = X[(i + 4) & 7]; v ^= t ^ (t << 7); + t = X[(i + 7) & 7]; t = t ^ (t << 13); v ^= t ^ (t << 9); + X[i] = v; + me.i = (i + 1) & 7; + return v; + }; + + function init(me, seed) { + var j, w, X = []; + + if (seed === (seed | 0)) { + // Seed state array using a 32-bit integer. + w = X[0] = seed; + } else { + // Seed state using a string. + seed = '' + seed; + for (j = 0; j < seed.length; ++j) { + X[j & 7] = (X[j & 7] << 15) ^ + (seed.charCodeAt(j) + X[(j + 1) & 7] << 13); + } + } + // Enforce an array length of 8, not all zeroes. + while (X.length < 8) X.push(0); + for (j = 0; j < 8 && X[j] === 0; ++j); + if (j == 8) w = X[7] = -1; else w = X[j]; + + me.x = X; + me.i = 0; + + // Discard an initial 256 values. + for (j = 256; j > 0; --j) { + me.next(); + } + } + + init(me, seed); + } + + function copy(f, t) { + t.x = f.x.slice(); + t.i = f.i; + return t; + } + + function impl(seed, opts) { + if (seed == null) seed = +(new Date); + var xg = new XorGen(seed), + state = opts && opts.state, + prng = function() { return (xg.next() >>> 0) / 0x100000000; }; + prng.double = function() { + do { + var top = xg.next() >>> 11, + bot = (xg.next() >>> 0) / 0x100000000, + result = (top + bot) / (1 << 21); + } while (result === 0); + return result; + }; + prng.int32 = xg.next; + prng.quick = prng; + if (state) { + if (state.x) copy(state, xg); + prng.state = function() { return copy(xg, {}); }; + } + return prng; + } + + if (module && module.exports) { + module.exports = impl; + } else if (define && define.amd) { + define(function() { return impl; }); + } else { + this.xorshift7 = impl; + } + + })( + commonjsGlobal, + module, // present in node.js + (typeof undefined) == 'function' // present with an AMD loader + ); + }); + + var xor4096 = createCommonjsModule(function (module) { + // A Javascript implementaion of Richard Brent's Xorgens xor4096 algorithm. + // + // This fast non-cryptographic random number generator is designed for + // use in Monte-Carlo algorithms. It combines a long-period xorshift + // generator with a Weyl generator, and it passes all common batteries + // of stasticial tests for randomness while consuming only a few nanoseconds + // for each prng generated. For background on the generator, see Brent's + // paper: "Some long-period random number generators using shifts and xors." + // http://arxiv.org/pdf/1004.3115v1.pdf + // + // Usage: + // + // var xor4096 = require('xor4096'); + // random = xor4096(1); // Seed with int32 or string. + // assert.equal(random(), 0.1520436450538547); // (0, 1) range, 53 bits. + // assert.equal(random.int32(), 1806534897); // signed int32, 32 bits. + // + // For nonzero numeric keys, this impelementation provides a sequence + // identical to that by Brent's xorgens 3 implementaion in C. This + // implementation also provides for initalizing the generator with + // string seeds, or for saving and restoring the state of the generator. + // + // On Chrome, this prng benchmarks about 2.1 times slower than + // Javascript's built-in Math.random(). + + (function(global, module, define) { + + function XorGen(seed) { + var me = this; + + // Set up generator function. + me.next = function() { + var w = me.w, + X = me.X, i = me.i, t, v; + // Update Weyl generator. + me.w = w = (w + 0x61c88647) | 0; + // Update xor generator. + v = X[(i + 34) & 127]; + t = X[i = ((i + 1) & 127)]; + v ^= v << 13; + t ^= t << 17; + v ^= v >>> 15; + t ^= t >>> 12; + // Update Xor generator array state. + v = X[i] = v ^ t; + me.i = i; + // Result is the combination. + return (v + (w ^ (w >>> 16))) | 0; + }; + + function init(me, seed) { + var t, v, i, j, w, X = [], limit = 128; + if (seed === (seed | 0)) { + // Numeric seeds initialize v, which is used to generates X. + v = seed; + seed = null; + } else { + // String seeds are mixed into v and X one character at a time. + seed = seed + '\0'; + v = 0; + limit = Math.max(limit, seed.length); + } + // Initialize circular array and weyl value. + for (i = 0, j = -32; j < limit; ++j) { + // Put the unicode characters into the array, and shuffle them. + if (seed) v ^= seed.charCodeAt((j + 32) % seed.length); + // After 32 shuffles, take v as the starting w value. + if (j === 0) w = v; + v ^= v << 10; + v ^= v >>> 15; + v ^= v << 4; + v ^= v >>> 13; + if (j >= 0) { + w = (w + 0x61c88647) | 0; // Weyl. + t = (X[j & 127] ^= (v + w)); // Combine xor and weyl to init array. + i = (0 == t) ? i + 1 : 0; // Count zeroes. + } + } + // We have detected all zeroes; make the key nonzero. + if (i >= 128) { + X[(seed && seed.length || 0) & 127] = -1; + } + // Run the generator 512 times to further mix the state before using it. + // Factoring this as a function slows the main generator, so it is just + // unrolled here. The weyl generator is not advanced while warming up. + i = 127; + for (j = 4 * 128; j > 0; --j) { + v = X[(i + 34) & 127]; + t = X[i = ((i + 1) & 127)]; + v ^= v << 13; + t ^= t << 17; + v ^= v >>> 15; + t ^= t >>> 12; + X[i] = v ^ t; + } + // Storing state as object members is faster than using closure variables. + me.w = w; + me.X = X; + me.i = i; + } + + init(me, seed); + } + + function copy(f, t) { + t.i = f.i; + t.w = f.w; + t.X = f.X.slice(); + return t; + } + function impl(seed, opts) { + if (seed == null) seed = +(new Date); + var xg = new XorGen(seed), + state = opts && opts.state, + prng = function() { return (xg.next() >>> 0) / 0x100000000; }; + prng.double = function() { + do { + var top = xg.next() >>> 11, + bot = (xg.next() >>> 0) / 0x100000000, + result = (top + bot) / (1 << 21); + } while (result === 0); + return result; + }; + prng.int32 = xg.next; + prng.quick = prng; + if (state) { + if (state.X) copy(state, xg); + prng.state = function() { return copy(xg, {}); }; + } + return prng; + } + + if (module && module.exports) { + module.exports = impl; + } else if (define && define.amd) { + define(function() { return impl; }); + } else { + this.xor4096 = impl; + } + + })( + commonjsGlobal, // window object or global + module, // present in node.js + (typeof undefined) == 'function' // present with an AMD loader + ); + }); + + var tychei = createCommonjsModule(function (module) { + // A Javascript implementaion of the "Tyche-i" prng algorithm by + // Samuel Neves and Filipe Araujo. + // See https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf + + (function(global, module, define) { + + function XorGen(seed) { + var me = this, strseed = ''; + + // Set up generator function. + me.next = function() { + var b = me.b, c = me.c, d = me.d, a = me.a; + b = (b << 25) ^ (b >>> 7) ^ c; + c = (c - d) | 0; + d = (d << 24) ^ (d >>> 8) ^ a; + a = (a - b) | 0; + me.b = b = (b << 20) ^ (b >>> 12) ^ c; + me.c = c = (c - d) | 0; + me.d = (d << 16) ^ (c >>> 16) ^ a; + return me.a = (a - b) | 0; + }; + + /* The following is non-inverted tyche, which has better internal + * bit diffusion, but which is about 25% slower than tyche-i in JS. + me.next = function() { + var a = me.a, b = me.b, c = me.c, d = me.d; + a = (me.a + me.b | 0) >>> 0; + d = me.d ^ a; d = d << 16 ^ d >>> 16; + c = me.c + d | 0; + b = me.b ^ c; b = b << 12 ^ d >>> 20; + me.a = a = a + b | 0; + d = d ^ a; me.d = d = d << 8 ^ d >>> 24; + me.c = c = c + d | 0; + b = b ^ c; + return me.b = (b << 7 ^ b >>> 25); + } + */ + + me.a = 0; + me.b = 0; + me.c = 2654435769 | 0; + me.d = 1367130551; + + if (seed === Math.floor(seed)) { + // Integer seed. + me.a = (seed / 0x100000000) | 0; + me.b = seed | 0; + } else { + // String seed. + strseed += seed; + } + + // Mix in string seed, then discard an initial batch of 64 values. + for (var k = 0; k < strseed.length + 20; k++) { + me.b ^= strseed.charCodeAt(k) | 0; + me.next(); + } + } + + function copy(f, t) { + t.a = f.a; + t.b = f.b; + t.c = f.c; + t.d = f.d; + return t; + } + function impl(seed, opts) { + var xg = new XorGen(seed), + state = opts && opts.state, + prng = function() { return (xg.next() >>> 0) / 0x100000000; }; + prng.double = function() { + do { + var top = xg.next() >>> 11, + bot = (xg.next() >>> 0) / 0x100000000, + result = (top + bot) / (1 << 21); + } while (result === 0); + return result; + }; + prng.int32 = xg.next; + prng.quick = prng; + if (state) { + if (typeof(state) == 'object') copy(state, xg); + prng.state = function() { return copy(xg, {}); }; + } + return prng; + } + + if (module && module.exports) { + module.exports = impl; + } else if (define && define.amd) { + define(function() { return impl; }); + } else { + this.tychei = impl; + } + + })( + commonjsGlobal, + module, // present in node.js + (typeof undefined) == 'function' // present with an AMD loader + ); + }); + + var seedrandom = createCommonjsModule(function (module) { + /* + Copyright 2014 David Bau. + + Permission is hereby granted, free of charge, to any person obtaining + a copy of this software and associated documentation files (the + "Software"), to deal in the Software without restriction, including + without limitation the rights to use, copy, modify, merge, publish, + distribute, sublicense, and/or sell copies of the Software, and to + permit persons to whom the Software is furnished to do so, subject to + the following conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + */ + + (function (pool, math) { + // + // The following constants are related to IEEE 754 limits. + // + var global = this, + width = 256, // each RC4 output is 0 <= x < 256 + chunks = 6, // at least six RC4 outputs for each double + digits = 52, // there are 52 significant digits in a double + rngname = 'random', // rngname: name for Math.random and Math.seedrandom + startdenom = math.pow(width, chunks), + significance = math.pow(2, digits), + overflow = significance * 2, + mask = width - 1, + nodecrypto; // node.js crypto module, initialized at the bottom. + + // + // seedrandom() + // This is the seedrandom function described above. + // + function seedrandom(seed, options, callback) { + var key = []; + options = (options == true) ? { entropy: true } : (options || {}); + + // Flatten the seed string or build one from local entropy if needed. + var shortseed = mixkey(flatten( + options.entropy ? [seed, tostring(pool)] : + (seed == null) ? autoseed() : seed, 3), key); + + // Use the seed to initialize an ARC4 generator. + var arc4 = new ARC4(key); + + // This function returns a random double in [0, 1) that contains + // randomness in every bit of the mantissa of the IEEE 754 value. + var prng = function() { + var n = arc4.g(chunks), // Start with a numerator n < 2 ^ 48 + d = startdenom, // and denominator d = 2 ^ 48. + x = 0; // and no 'extra last byte'. + while (n < significance) { // Fill up all significant digits by + n = (n + x) * width; // shifting numerator and + d *= width; // denominator and generating a + x = arc4.g(1); // new least-significant-byte. + } + while (n >= overflow) { // To avoid rounding up, before adding + n /= 2; // last byte, shift everything + d /= 2; // right using integer math until + x >>>= 1; // we have exactly the desired bits. + } + return (n + x) / d; // Form the number within [0, 1). + }; + + prng.int32 = function() { return arc4.g(4) | 0; }; + prng.quick = function() { return arc4.g(4) / 0x100000000; }; + prng.double = prng; + + // Mix the randomness into accumulated entropy. + mixkey(tostring(arc4.S), pool); + + // Calling convention: what to return as a function of prng, seed, is_math. + return (options.pass || callback || + function(prng, seed, is_math_call, state) { + if (state) { + // Load the arc4 state from the given state if it has an S array. + if (state.S) { copy(state, arc4); } + // Only provide the .state method if requested via options.state. + prng.state = function() { return copy(arc4, {}); }; + } + + // If called as a method of Math (Math.seedrandom()), mutate + // Math.random because that is how seedrandom.js has worked since v1.0. + if (is_math_call) { math[rngname] = prng; return seed; } + + // Otherwise, it is a newer calling convention, so return the + // prng directly. + else return prng; + })( + prng, + shortseed, + 'global' in options ? options.global : (this == math), + options.state); + } + math['seed' + rngname] = seedrandom; + + // + // ARC4 + // + // An ARC4 implementation. The constructor takes a key in the form of + // an array of at most (width) integers that should be 0 <= x < (width). + // + // The g(count) method returns a pseudorandom integer that concatenates + // the next (count) outputs from ARC4. Its return value is a number x + // that is in the range 0 <= x < (width ^ count). + // + function ARC4(key) { + var t, keylen = key.length, + me = this, i = 0, j = me.i = me.j = 0, s = me.S = []; + + // The empty key [] is treated as [0]. + if (!keylen) { key = [keylen++]; } + + // Set up S using the standard key scheduling algorithm. + while (i < width) { + s[i] = i++; + } + for (i = 0; i < width; i++) { + s[i] = s[j = mask & (j + key[i % keylen] + (t = s[i]))]; + s[j] = t; + } + + // The "g" method returns the next (count) outputs as one number. + (me.g = function(count) { + // Using instance members instead of closure state nearly doubles speed. + var t, r = 0, + i = me.i, j = me.j, s = me.S; + while (count--) { + t = s[i = mask & (i + 1)]; + r = r * width + s[mask & ((s[i] = s[j = mask & (j + t)]) + (s[j] = t))]; + } + me.i = i; me.j = j; + return r; + // For robust unpredictability, the function call below automatically + // discards an initial batch of values. This is called RC4-drop[256]. + // See http://google.com/search?q=rsa+fluhrer+response&btnI + })(width); + } + + // + // copy() + // Copies internal state of ARC4 to or from a plain object. + // + function copy(f, t) { + t.i = f.i; + t.j = f.j; + t.S = f.S.slice(); + return t; + } + // + // flatten() + // Converts an object tree to nested arrays of strings. + // + function flatten(obj, depth) { + var result = [], typ = (typeof obj), prop; + if (depth && typ == 'object') { + for (prop in obj) { + try { result.push(flatten(obj[prop], depth - 1)); } catch (e) {} + } + } + return (result.length ? result : typ == 'string' ? obj : obj + '\0'); + } + + // + // mixkey() + // Mixes a string seed into a key that is an array of integers, and + // returns a shortened string seed that is equivalent to the result key. + // + function mixkey(seed, key) { + var stringseed = seed + '', smear, j = 0; + while (j < stringseed.length) { + key[mask & j] = + mask & ((smear ^= key[mask & j] * 19) + stringseed.charCodeAt(j++)); + } + return tostring(key); + } + + // + // autoseed() + // Returns an object for autoseeding, using window.crypto and Node crypto + // module if available. + // + function autoseed() { + try { + var out; + if (nodecrypto && (out = nodecrypto.randomBytes)) { + // The use of 'out' to remember randomBytes makes tight minified code. + out = out(width); + } else { + out = new Uint8Array(width); + (global.crypto || global.msCrypto).getRandomValues(out); + } + return tostring(out); + } catch (e) { + var browser = global.navigator, + plugins = browser && browser.plugins; + return [+new Date, global, plugins, global.screen, tostring(pool)]; + } + } + + // + // tostring() + // Converts an array of charcodes to a string + // + function tostring(a) { + return String.fromCharCode.apply(0, a); + } + + // + // When seedrandom.js is loaded, we immediately mix a few bits + // from the built-in RNG into the entropy pool. Because we do + // not want to interfere with deterministic PRNG state later, + // seedrandom will not call math.random on its own again after + // initialization. + // + mixkey(math.random(), pool); + + // + // Nodejs and AMD support: export the implementation as a module using + // either convention. + // + if (module.exports) { + module.exports = seedrandom; + // When in node.js, try using crypto package for autoseeding. + try { + nodecrypto = require('crypto'); + } catch (ex) {} + } + + // End anonymous scope, and pass initial values. + })( + [], // pool: entropy pool starts empty + Math // math: package containing random, pow, and seedrandom + ); + }); + + // A library of seedable RNGs implemented in Javascript. + // + // Usage: + // + // var seedrandom = require('seedrandom'); + // var random = seedrandom(1); // or any seed. + // var x = random(); // 0 <= x < 1. Every bit is random. + // var x = random.quick(); // 0 <= x < 1. 32 bits of randomness. + + // alea, a 53-bit multiply-with-carry generator by Johannes Baagøe. + // Period: ~2^116 + // Reported to pass all BigCrush tests. + + + // xor128, a pure xor-shift generator by George Marsaglia. + // Period: 2^128-1. + // Reported to fail: MatrixRank and LinearComp. + + + // xorwow, George Marsaglia's 160-bit xor-shift combined plus weyl. + // Period: 2^192-2^32 + // Reported to fail: CollisionOver, SimpPoker, and LinearComp. + + + // xorshift7, by François Panneton and Pierre L'ecuyer, takes + // a different approach: it adds robustness by allowing more shifts + // than Marsaglia's original three. It is a 7-shift generator + // with 256 bits, that passes BigCrush with no systmatic failures. + // Period 2^256-1. + // No systematic BigCrush failures reported. + + + // xor4096, by Richard Brent, is a 4096-bit xor-shift with a + // very long period that also adds a Weyl generator. It also passes + // BigCrush with no systematic failures. Its long period may + // be useful if you have many generators and need to avoid + // collisions. + // Period: 2^4128-2^32. + // No systematic BigCrush failures reported. + + + // Tyche-i, by Samuel Neves and Filipe Araujo, is a bit-shifting random + // number generator derived from ChaCha, a modern stream cipher. + // https://eden.dei.uc.pt/~sneves/pubs/2011-snfa2.pdf + // Period: ~2^127 + // No systematic BigCrush failures reported. + + + // The original ARC4-based prng included in this library. + // Period: ~2^1600 + + + seedrandom.alea = alea; + seedrandom.xor128 = xor128; + seedrandom.xorwow = xorwow; + seedrandom.xorshift7 = xorshift7; + seedrandom.xor4096 = xor4096; + seedrandom.tychei = tychei; + + var seedrandom$1 = seedrandom; + var seedrandom_1 = seedrandom$1.alea; + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // https://en.wikipedia.org/wiki/Marsaglia_polar_method + var MPRandGauss = /** @class */ (function () { + function MPRandGauss(mean, stdDeviation, dtype, truncated, seed) { + this.mean = mean; + this.stdDev = stdDeviation; + this.dtype = dtype; + this.nextVal = NaN; + this.truncated = truncated; + if (this.truncated) { + this.upper = this.mean + this.stdDev * 2; + this.lower = this.mean - this.stdDev * 2; + } + var seedValue = seed ? seed : Math.random(); + this.random = seedrandom_1(seedValue.toString()); + } + /** Returns next sample from a Gaussian distribution. */ + MPRandGauss.prototype.nextValue = function () { + if (!isNaN(this.nextVal)) { + var value = this.nextVal; + this.nextVal = NaN; + return value; + } + var resultX, resultY; + var isValid = false; + while (!isValid) { + var v1 = void 0, v2 = void 0, s = void 0; + do { + v1 = 2 * this.random() - 1; + v2 = 2 * this.random() - 1; + s = v1 * v1 + v2 * v2; + } while (s >= 1 || s === 0); + var mul = Math.sqrt(-2.0 * Math.log(s) / s); + resultX = this.mean + this.stdDev * v1 * mul; + resultY = this.mean + this.stdDev * v2 * mul; + if (!this.truncated || this.isValidTruncated(resultX)) { + isValid = true; + } + } + if (!this.truncated || this.isValidTruncated(resultY)) { + this.nextVal = this.convertValue(resultY); + } + return this.convertValue(resultX); + }; + /** Handles proper rounding for non-floating-point numbers. */ + MPRandGauss.prototype.convertValue = function (value) { + if (this.dtype == null || this.dtype === 'float32') { + return value; + } + return Math.round(value); + }; + /** Returns true if less than 2-standard-deviations from the mean. */ + MPRandGauss.prototype.isValidTruncated = function (value) { + return value <= this.upper && value >= this.lower; + }; + return MPRandGauss; + }()); + // Marsaglia, George, and Wai Wan Tsang. 2000. "A Simple Method for Generating + // Gamma Variables." + var RandGamma = /** @class */ (function () { + function RandGamma(alpha, beta, dtype, seed) { + this.alpha = alpha; + this.beta = 1 / beta; // convert rate to scale parameter + this.dtype = dtype; + var seedValue = seed ? seed : Math.random(); + this.randu = seedrandom_1(seedValue.toString()); + this.randn = new MPRandGauss(0, 1, dtype, false, this.randu()); + if (alpha < 1) { + this.d = alpha + (2 / 3); + } + else { + this.d = alpha - (1 / 3); + } + this.c = 1 / Math.sqrt(9 * this.d); + } + /** Returns next sample from a gamma distribution. */ + RandGamma.prototype.nextValue = function () { + var x2, v0, v1, x, u, v; + while (true) { + do { + x = this.randn.nextValue(); + v = 1 + (this.c * x); + } while (v <= 0); + v *= v * v; + x2 = x * x; + v0 = 1 - (0.331 * x2 * x2); + v1 = (0.5 * x2) + (this.d * (1 - v + Math.log(v))); + u = this.randu(); + if (u < v0 || Math.log(u) < v1) { + break; + } + } + v = (1 / this.beta) * this.d * v; + if (this.alpha < 1) { + v *= Math.pow(this.randu(), 1 / this.alpha); + } + return this.convertValue(v); + }; + /** Handles proper rounding for non-floating-point numbers. */ + RandGamma.prototype.convertValue = function (value) { + if (this.dtype === 'float32') { + return value; + } + return Math.round(value); + }; + return RandGamma; + }()); + var UniformRandom = /** @class */ (function () { + function UniformRandom(min, max, dtype, seed) { + var _this = this; + if (min === void 0) { min = 0; } + if (max === void 0) { max = 1; } + /** Handles proper rounding for non floating point numbers. */ + this.canReturnFloat = function () { + return (_this.dtype == null || _this.dtype === 'float32'); + }; + this.min = min; + this.range = max - min; + this.dtype = dtype; + if (seed == null) { + seed = Math.random(); + } + if (typeof seed === 'number') { + seed = seed.toString(); + } + if (!this.canReturnFloat() && this.range <= 1) { + throw new Error("The difference between " + min + " - " + max + " <= 1 and dtype is not float"); + } + this.random = seedrandom_1(seed); + } + UniformRandom.prototype.convertValue = function (value) { + if (this.canReturnFloat()) { + return value; + } + return Math.round(value); + }; + UniformRandom.prototype.nextValue = function () { + return this.convertValue(this.min + this.range * this.random()); + }; + return UniformRandom; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Creates a new tensor with the same values and shape as the specified + * tensor. + * + * ```js + * const x = tf.tensor([1, 2]); + * + * x.clone().print(); + * ``` + * + * @param x The tensor to clone. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function clone_(x) { + var $x = convertToTensor(x, 'x', 'clone', null); + var der = function (dy) { + return { $x: function () { return dy.toFloat(); } }; + }; + return ENGINE.runKernelFunc(function () { return ENGINE.makeTensorFromDataId($x.dataId, $x.shape, $x.dtype); }, { $x: $x }, der); + } + /** + * Create an identity matrix. + * + * @param numRows Number of rows. + * @param numColumns Number of columns. Defaults to `numRows`. + * @param batchShape If provided, will add the batch shape to the beginning + * of the shape of the returned `tf.Tensor` by repeating the identity + * matrix. + * @param dtype Data type. + * @returns Identity matrix of the specified size and data type, possibly + * with batch repetition if `batchShape` is specified. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function eye_(numRows, numColumns, batchShape, dtype) { + if (dtype === void 0) { dtype = 'float32'; } + if (numColumns == null) { + numColumns = numRows; + } + var buff = buffer([numRows, numColumns], dtype); + var n = numRows <= numColumns ? numRows : numColumns; + for (var i = 0; i < n; ++i) { + buff.set(1, i, i); + } + var out = buff.toTensor().as2D(numRows, numColumns); + if (batchShape == null) { + return out; + } + else { + if (batchShape.length === 1) { + return tile(expandDims(out, 0), [batchShape[0], 1, 1]); + } + else if (batchShape.length === 2) { + return tile(expandDims(expandDims(out, 0), 0), [batchShape[0], batchShape[1], 1, 1]); + } + else if (batchShape.length === 3) { + return tile(expandDims(expandDims(expandDims(out, 0), 0), 0), [batchShape[0], batchShape[1], batchShape[2], 1, 1]); + } + else { + throw new Error("eye() currently supports only 1D and 2D " + + ( + // tslint:disable-next-line:no-any + "batchShapes, but received " + batchShape.length + "D.")); + } + } + } + /** + * Creates a `tf.Tensor` with values sampled from a normal distribution. + * + * ```js + * tf.randomNormal([2, 2]).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param mean The mean of the normal distribution. + * @param stdDev The standard deviation of the normal distribution. + * @param dtype The data type of the output. + * @param seed The seed for the random number generator. + */ + /** @doc {heading: 'Tensors', subheading: 'Random'} */ + function randomNormal_(shape, mean, stdDev, dtype, seed) { + if (mean === void 0) { mean = 0; } + if (stdDev === void 0) { stdDev = 1; } + if (dtype != null && dtype === 'bool') { + throw new Error("Unsupported data type " + dtype); + } + var randGauss = new MPRandGauss(mean, stdDev, dtype, false /* truncated */, seed); + var res = buffer(shape, dtype); + for (var i = 0; i < res.values.length; i++) { + res.values[i] = randGauss.nextValue(); + } + return res.toTensor(); + } + /** + * Creates a `tf.Tensor` with values sampled from a truncated normal + * distribution. + * + * ```js + * tf.truncatedNormal([2, 2]).print(); + * ``` + * + * The generated values follow a normal distribution with specified mean and + * standard deviation, except that values whose magnitude is more than 2 + * standard deviations from the mean are dropped and re-picked. + * + * @param shape An array of integers defining the output tensor shape. + * @param mean The mean of the normal distribution. + * @param stdDev The standard deviation of the normal distribution. + * @param dtype The data type of the output tensor. + * @param seed The seed for the random number generator. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function truncatedNormal_(shape, mean, stdDev, dtype, seed) { + if (mean === void 0) { mean = 0; } + if (stdDev === void 0) { stdDev = 1; } + if (dtype != null && dtype === 'bool') { + throw new Error("Unsupported data type " + dtype); + } + var randGauss = new MPRandGauss(mean, stdDev, dtype, true /* truncated */, seed); + var res = buffer(shape, dtype); + for (var i = 0; i < res.values.length; i++) { + res.values[i] = randGauss.nextValue(); + } + return res.toTensor(); + } + /** + * Creates a `tf.Tensor` with values sampled from a gamma distribution. + * + * ```js + * tf.randomGamma([2, 2], 1).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param alpha The shape parameter of the gamma distribution. + * @param beta The inverse scale parameter of the gamma distribution. Defaults + * to 1. + * @param dtype The data type of the output. Defaults to float32. + * @param seed The seed for the random number generator. + */ + /** @doc {heading: 'Tensors', subheading: 'Random'} */ + function randomGamma_(shape, alpha, beta, dtype, seed) { + if (beta === void 0) { beta = 1; } + if (dtype === void 0) { dtype = 'float32'; } + if (beta == null) { + beta = 1; + } + if (dtype == null) { + dtype = 'float32'; + } + if (dtype !== 'float32' && dtype !== 'int32') { + throw new Error("Unsupported data type " + dtype); + } + var rgamma = new RandGamma(alpha, beta, dtype, seed); + var res = buffer(shape, dtype); + for (var i = 0; i < res.values.length; i++) { + res.values[i] = rgamma.nextValue(); + } + return res.toTensor(); + } + /** + * Creates a `tf.Tensor` with values sampled from a uniform distribution. + * + * The generated values follow a uniform distribution in the range [minval, + * maxval). The lower bound minval is included in the range, while the upper + * bound maxval is excluded. + * + * ```js + * tf.randomUniform([2, 2]).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param minval The lower bound on the range of random values to generate. + * Defaults to 0. + * @param maxval The upper bound on the range of random values to generate. + * Defaults to 1. + * @param dtype The data type of the output tensor. Defaults to 'float32'. + */ + /** @doc {heading: 'Tensors', subheading: 'Random'} */ + function randomUniform_(shape, minval, maxval, dtype, seed) { + if (minval === void 0) { minval = 0; } + if (maxval === void 0) { maxval = 1; } + if (dtype === void 0) { dtype = 'float32'; } + var res = buffer(shape, dtype); + var random = new UniformRandom(minval, maxval, null, seed); + for (var i = 0; i < res.values.length; i++) { + res.values[i] = random.nextValue(); + } + return res.toTensor(); + } + /** + * Creates a `tf.Tensor` with values sampled from a random number generator + * function defined by the user. + * + * @param shape An array of integers defining the output tensor shape. + * @param randFunction A random number generator function which is called + * for each element in the output tensor. + * @param dtype The data type of the output tensor. Defaults to 'float32'. + */ + function rand_(shape, randFunction, dtype) { + var size = sizeFromShape(shape); + var values = null; + if (dtype == null || dtype === 'float32') { + values = new Float32Array(size); + } + else if (dtype === 'int32') { + values = new Int32Array(size); + } + else if (dtype === 'bool') { + values = new Uint8Array(size); + } + else { + throw new Error("Unknown data type " + dtype); + } + for (var i = 0; i < size; i++) { + values[i] = randFunction(); + } + return ENGINE.makeTensor(values, shape, dtype); + } + /** + * Creates a `tf.Tensor` with values drawn from a multinomial distribution. + * + * ```js + * const probs = tf.tensor([.75, .25]); + * tf.multinomial(probs, 3).print(); + * ``` + * + * @param logits 1D array with unnormalized log-probabilities, or + * 2D array of shape `[batchSize, numOutcomes]`. See the `normalized` + * parameter. + * @param numSamples Number of samples to draw for each row slice. + * @param seed The seed number. + * @param normalized Whether the provided `logits` are normalized true + * probabilities (sum to 1). Defaults to false. + * @return 1D array of shape `[numSamples]`, or 2D array of shape + * `[batchSize, numSamples]`, depending on the rank of the input. + */ + /** @doc {heading: 'Tensors', subheading: 'Random'} */ + function multinomial_(logits, numSamples, seed, normalized) { + if (normalized === void 0) { normalized = false; } + var $logits = convertToTensor(logits, 'logits', 'multinomial'); + var numOutcomes = $logits.size; + var origRank = $logits.rank; + if (numOutcomes < 2) { + throw new Error("Error in multinomial: you need at least 2 outcomes, but got " + + (numOutcomes + ".")); + } + if (origRank > 2) { + throw new Error("Rank of probabilities must be 1 or 2, but is " + origRank); + } + seed = seed || Math.random(); + var logits2D = origRank === 1 ? $logits.as2D(1, -1) : $logits; + var res = ENGINE.runKernelFunc(function (backend) { return backend.multinomial(logits2D, normalized, numSamples, seed); }, { logits2D: logits2D }); + return origRank === 1 ? res.as1D() : res; + } + /** + * Creates a one-hot `tf.Tensor`. The locations represented by `indices` take + * value `onValue` (defaults to 1), while all other locations take value + * `offValue` (defaults to 0). If `indices` is rank `R`, the output has rank + * `R+1` with the last axis of size `depth`. + * + * ```js + * tf.oneHot(tf.tensor1d([0, 1], 'int32'), 3).print(); + * ``` + * + * @param indices `tf.Tensor` of indices with dtype `int32`. + * @param depth The depth of the one hot dimension. + * @param onValue A number used to fill in the output when the index matches + * the location. + * @param offValue A number used to fill in the output when the index does + * not match the location. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function oneHot_(indices, depth, onValue, offValue) { + if (onValue === void 0) { onValue = 1; } + if (offValue === void 0) { offValue = 0; } + if (depth < 2) { + throw new Error("Error in oneHot: depth must be >=2, but it is " + depth); + } + var $indices = convertToTensor(indices, 'indices', 'oneHot', 'int32'); + var outShape = $indices.shape.concat([depth]); + $indices = $indices.flatten(); + var grad = function (dy) { + return { $indices: function () { return zeros($indices.shape, 'float32'); } }; + }; + var result = ENGINE.runKernelFunc(function (backend) { return backend.oneHot($indices, depth, onValue, offValue); }, { $indices: $indices }, grad); + return result.reshape(outShape); + } + /** + * Reshapes a `tf.Tensor` to a given shape. + * + * Given an input tensor, returns a new tensor with the same values as the + * input tensor with shape `shape`. + * + * If one component of shape is the special value -1, the size of that + * dimension is computed so that the total size remains constant. In + * particular, a shape of [-1] flattens into 1-D. At most one component of + * shape can be -1. + * + * If shape is 1-D or higher, then the operation returns a tensor with shape + * shape filled with the values of tensor. In this case, the number of + * elements implied by shape must be the same as the number of elements in + * tensor. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * x.reshape([2, 2]).print(); + * ``` + * + * @param x The input tensor to be reshaped. + * @param shape An array of integers defining the output tensor shape. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function reshape_(x, shape) { + var $x = convertToTensor(x, 'x', 'reshape', null); + shape = inferFromImplicitShape(shape, $x.size); + assert($x.size === sizeFromShape(shape), function () { return 'new shape and old shape must have the same number of elements.'; }); + var grad = function (dy) { + return { x: function () { return dy.reshape($x.shape); } }; + }; + var attrs = { shape: shape }; + return ENGINE.runKernelFunc(function (backend) { return backend.reshape($x, shape); }, { x: $x }, grad, 'Reshape', attrs); + } + /** + * Removes dimensions of size 1 from the shape of a `tf.Tensor`. + * + * ```js + * const x = tf.tensor([1, 2, 3, 4], [1, 1, 4]); + * x.squeeze().print(); + * ``` + * + * @param x The input tensor to be squeezed. + * @param axis An optional list of numbers. If specified, only + * squeezes the dimensions listed. The dimension index starts at 0. It + * is an error to squeeze a dimension that is not 1. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function squeeze_(x, axis) { + var $x = convertToTensor(x, 'x', 'squeeze'); + return reshape($x, squeezeShape($x.shape, axis).newShape); + } + /** + * Casts a `tf.Tensor` to a new dtype. + * + * ```js + * const x = tf.tensor1d([1.5, 2.5, 3]); + * tf.cast(x, 'int32').print(); + * ``` + * @param x The input tensor to be casted. + * @param dtype The dtype to cast the input tensor to. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function cast_(x, dtype) { + var $x = convertToTensor(x, 'x', 'cast'); + // Sanity checks. + if (!isValidDtype(dtype)) { + throw new Error("Failed to cast to unknown dtype " + dtype); + } + if (dtype === 'string' && $x.dtype !== 'string' || + dtype !== 'string' && $x.dtype === 'string') { + throw new Error('Only strings can be casted to strings'); + } + var grad = function (dy) { + return { x: function () { return dy.clone(); } }; + }; + var attrs = { dtype: dtype }; + return ENGINE.runKernelFunc(function (backend) { return backend.cast($x, dtype); }, { x: $x }, grad, 'Cast', attrs); + } + /** + * Construct a tensor by repeating it the number of times given by reps. + * + * This operation creates a new tensor by replicating `input` `reps` + * times. The output tensor's i'th dimension has `input.shape[i] * + * reps[i]` elements, and the values of `input` are replicated + * `reps[i]` times along the i'th dimension. For example, tiling + * `[a, b, c, d]` by `[2]` produces `[a, b, c, d, a, b, c, d]`. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * + * a.tile([2]).print(); // or a.tile([2]) + * ``` + * + * ```js + * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * a.tile([1, 2]).print(); // or a.tile([1, 2]) + * ``` + * @param x The tensor to tile. + * @param reps Determines the number of replications per dimension. + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function tile_(x, reps) { + var parseAs = null; + var $x = convertToTensor(x, 'x', 'tile', parseAs); + assert($x.rank === reps.length, function () { return "Error in transpose: rank of input " + $x.rank + " " + + ("must match length of reps " + reps + "."); }); + var grad = function (dy, saved) { + var $x = saved[0]; + var derX = function () { + var xGrad = zerosLike($x); + // TODO(cais): Maybe reduce memory footprint by avoiding repeated + // slicing. + if ($x.rank === 1) { + for (var i = 0; i < reps[0]; ++i) { + xGrad = xGrad.add(dy.slice([i * $x.shape[0]], [$x.shape[0]])); + } + } + else if ($x.rank === 2) { + for (var i = 0; i < reps[0]; ++i) { + for (var j = 0; j < reps[1]; ++j) { + xGrad = xGrad.add(dy.slice([i * $x.shape[0], j * $x.shape[1]], [$x.shape[0], $x.shape[1]])); + } + } + } + else if ($x.rank === 3) { + for (var i = 0; i < reps[0]; ++i) { + for (var j = 0; j < reps[1]; ++j) { + for (var k = 0; k < reps[2]; ++k) { + xGrad = xGrad.add(dy.slice([i * $x.shape[0], j * $x.shape[1], k * $x.shape[2]], [$x.shape[0], $x.shape[1], $x.shape[2]])); + } + } + } + } + else if ($x.rank === 4) { + for (var i = 0; i < reps[0]; ++i) { + for (var j = 0; j < reps[1]; ++j) { + for (var k = 0; k < reps[2]; ++k) { + for (var l = 0; l < reps[3]; ++l) { + xGrad = xGrad.add(dy.slice([ + i * $x.shape[0], j * $x.shape[1], k * $x.shape[2], + l * $x.shape[3] + ], [$x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]])); + } + } + } + } + } + else { + throw new Error("Gradient for tile operation is not implemented for rank-" + + ($x.rank + " tensors yet.")); + } + return xGrad; + }; + return { $x: derX }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.tile($x, reps); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Pads a `tf.Tensor1D` with a given value and paddings. See `pad` for details. + */ + function pad1d_(x, paddings, constantValue) { + if (constantValue === void 0) { constantValue = 0; } + assert(paddings.length === 2, function () { return 'Invalid number of paddings. Must be length of 2.'; }); + return pad(x, [paddings], constantValue); + } + /** + * Pads a `tf.Tensor2D` with a given value and paddings. See `pad` for details. + */ + function pad2d_(x, paddings, constantValue) { + if (constantValue === void 0) { constantValue = 0; } + assert(paddings.length === 2 && paddings[0].length === 2 && + paddings[1].length === 2, function () { return 'Invalid number of paddings. Must be length of 2 each.'; }); + return pad(x, paddings, constantValue); + } + /** + * Pads a `tf.Tensor3D` with a given value and paddings. See `pad` for details. + */ + function pad3d_(x, paddings, constantValue) { + if (constantValue === void 0) { constantValue = 0; } + assert(paddings.length === 3 && paddings[0].length === 2 && + paddings[1].length === 2 && paddings[2].length === 2, function () { return 'Invalid number of paddings. Must be length of 2 each.'; }); + return pad(x, paddings, constantValue); + } + /** + * Pads a `tf.Tensor4D` with a given value and paddings. See `pad` for details. + */ + function pad4d_(x, paddings, constantValue) { + if (constantValue === void 0) { constantValue = 0; } + assert(paddings.length === 4 && paddings[0].length === 2 && + paddings[1].length === 2 && paddings[2].length === 2 && + paddings[3].length === 2, function () { return 'Invalid number of paddings. Must be length of 2 each.'; }); + return pad(x, paddings, constantValue); + } + /** + * Pads a `tf.Tensor` with a given value and paddings. + * + * This operation currently only implements the `CONSTANT` mode. + * + * Also available are stricter rank-specific methods with the same signature + * as this method that assert that `paddings` is of given length. + * - `tf.pad1d` + * - `tf.pad2d` + * - `tf.pad3d` + * - `tf.pad4d` + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * x.pad([[1, 2]]).print(); + * ``` + * @param x The tensor to pad. + * @param paddings An array of length `R` (the rank of the tensor), where + * each element is a length-2 tuple of ints `[padBefore, padAfter]`, + * specifying how much to pad along each dimension of the tensor. + * @param constantValue The pad value to use. Defaults to 0. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function pad_(x, paddings, constantValue) { + if (constantValue === void 0) { constantValue = 0; } + var $x = convertToTensor(x, 'x', 'pad'); + if ($x.rank === 0) { + throw new Error('pad(scalar) is not defined. Pass non-scalar to pad'); + } + var grad = function (dy) { + // Pad introduces values around the original tensor, so the gradient + // slices the original shape out of the gradient. + var begin = paddings.map(function (p) { return p[0]; }); + return { x: function () { return dy.slice(begin, $x.shape); } }; + }; + var attrs = { paddings: paddings, constantValue: constantValue }; + return ENGINE.runKernelFunc(function (backend) { return backend.pad($x, paddings, constantValue); }, { x: $x }, grad, 'PadV2', attrs); + } + /** + * Stacks a list of rank-`R` `tf.Tensor`s into one rank-`(R+1)` `tf.Tensor`. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * const c = tf.tensor1d([5, 6]); + * tf.stack([a, b, c]).print(); + * ``` + * + * @param tensors A list of tensor objects with the same shape and dtype. + * @param axis The axis to stack along. Defaults to 0 (the first dim). + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function stack_(tensors, axis) { + if (axis === void 0) { axis = 0; } + var $tensors = convertToTensorArray(tensors, 'tensors', 'stack'); + assert($tensors.length >= 1, function () { return 'Pass at least one tensor to tf.stack'; }); + if ($tensors.length === 1) { + return $tensors[0].expandDims(axis); + } + var rank = $tensors[0].rank; + var shape = $tensors[0].shape; + var dtype = $tensors[0].dtype; + assert(axis <= rank, function () { return 'Axis must be <= rank of the tensor'; }); + $tensors.forEach(function (t) { + assertShapesMatch(shape, t.shape, 'All tensors passed to stack must have matching shapes'); + }); + $tensors.forEach(function (t) { + assert(dtype === t.dtype, function () { return 'All tensors passed to stack must have matching dtypes'; }); + }); + var expandedTensors = $tensors.map(function (t) { return t.expandDims(axis); }); + return concat(expandedTensors, axis); + } + /** + * This operation reshapes the "batch" dimension 0 into `M + 1` dimensions of + * shape `blockShape + [batch]`, interleaves these blocks back into the grid + * defined by the spatial dimensions `[1, ..., M]`, to obtain a result with + * the same rank as the input. The spatial dimensions of this intermediate + * result are then optionally cropped according to `crops` to produce the + * output. This is the reverse of `tf.spaceToBatchND`. See below for a precise + * description. + * + * ```js + * const x = tf.tensor4d([1, 2, 3, 4], [4, 1, 1, 1]); + * const blockShape = [2, 2]; + * const crops = [[0, 0], [0, 0]]; + * + * x.batchToSpaceND(blockShape, crops).print(); + * ``` + * + * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + + * remainingShape`, where spatialShape has `M` dimensions. + * @param blockShape A 1-D array. Must have shape `[M]`, all values must + * be >= 1. + * @param crops A 2-D array. Must have shape `[M, 2]`, all values must be >= 0. + * `crops[i] = [cropStart, cropEnd]` specifies the amount to crop from input + * dimension `i + 1`, which corresponds to spatial dimension `i`. It is required + * that `cropStart[i] + cropEnd[i] <= blockShape[i] * inputShape[i + 1]` + * + * This operation is equivalent to the following steps: + * + * 1. Reshape `x` to `reshaped` of shape: `[blockShape[0], ..., + * blockShape[M-1], batch / prod(blockShape), x.shape[1], ..., + * x.shape[N-1]]` + * + * 2. Permute dimensions of `reshaped`to produce `permuted` of shape `[batch / + * prod(blockShape),x.shape[1], blockShape[0], ..., x.shape[M], + * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` + * + * 3. Reshape `permuted` to produce `reshapedPermuted` of shape `[batch / + * prod(blockShape),x.shape[1] * blockShape[0], ..., x.shape[M] * + * blockShape[M-1],x.shape[M+1], ..., x.shape[N-1]]` + * + * 4. Crop the start and end of dimensions `[1, ..., M]` of `reshapedPermuted` + * according to `crops` to produce the output of shape: `[batch / + * prod(blockShape),x.shape[1] * blockShape[0] - crops[0,0] - crops[0,1], + * ..., x.shape[M] * blockShape[M-1] - crops[M-1,0] - + * crops[M-1,1],x.shape[M+1], ..., x.shape[N-1]]` + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function batchToSpaceND_(x, blockShape, crops) { + var $x = convertToTensor(x, 'x', 'batchToSpaceND'); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + assert($x.rank >= 1 + blockShape.length, function () { return "input rank is " + $x.rank + " but should be > than blockShape.length " + blockShape.length; }); + assert(crops.length === blockShape.length, function () { return "crops.length is " + crops.length + " but should be equal to blockShape.length " + blockShape.length; }); + assert($x.shape[0] % prod === 0, function () { return "input tensor batch is " + $x.shape[0] + " but is not divisible by the product of " + + ("the elements of blockShape " + blockShape.join(' * ') + " === " + prod); }); + var grad = function (dy) { + return { $x: function () { return dy.spaceToBatchND(blockShape, crops); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.batchToSpaceND($x, blockShape, crops); }, { $x: $x }, grad); + } + /** + * This operation divides "spatial" dimensions `[1, ..., M]` of the input into + * a grid of blocks of shape `blockShape`, and interleaves these blocks with + * the "batch" dimension (0) such that in the output, the spatial + * dimensions `[1, ..., M]` correspond to the position within the grid, + * and the batch dimension combines both the position within a spatial block + * and the original batch position. Prior to division into blocks, + * the spatial dimensions of the input are optionally zero padded + * according to `paddings`. See below for a precise description. + * + * ```js + * const x = tf.tensor4d([1, 2, 3, 4], [1, 2, 2, 1]); + * const blockShape = [2, 2]; + * const paddings = [[0, 0], [0, 0]]; + * + * x.spaceToBatchND(blockShape, paddings).print(); + * ``` + * + * @param x A `tf.Tensor`. N-D with `x.shape` = `[batch] + spatialShape + + * remainingShape`, where spatialShape has `M` dimensions. + * @param blockShape A 1-D array. Must have shape `[M]`, all values must + * be >= 1. + * @param paddings A 2-D array. Must have shape `[M, 2]`, all values must be >= + * 0. `paddings[i] = [padStart, padEnd]` specifies the amount to zero-pad + * from input dimension `i + 1`, which corresponds to spatial dimension `i`. It + * is required that + * `(inputShape[i + 1] + padStart + padEnd) % blockShape[i] === 0` + * + * This operation is equivalent to the following steps: + * + * 1. Zero-pad the start and end of dimensions `[1, ..., M]` of the input + * according to `paddings` to produce `padded` of shape paddedShape. + * + * 2. Reshape `padded` to `reshapedPadded` of shape: + * `[batch] + [paddedShape[1] / blockShape[0], blockShape[0], ..., + * paddedShape[M] / blockShape[M-1], blockShape[M-1]] + remainingShape` + * + * 3. Permute dimensions of `reshapedPadded` to produce `permutedReshapedPadded` + * of shape: `blockShape + [batch] + [paddedShape[1] / blockShape[0], ..., + * paddedShape[M] / blockShape[M-1]] + remainingShape` + * + * 4. Reshape `permutedReshapedPadded` to flatten `blockShape` into the + * batch dimension, producing an output tensor of shape: + * `[batch * prod(blockShape)] + [paddedShape[1] / blockShape[0], ..., + * paddedShape[M] / blockShape[M-1]] + remainingShape` + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function spaceToBatchND_(x, blockShape, paddings) { + var $x = convertToTensor(x, 'x', 'spaceToBatchND'); + assert($x.rank >= 1 + blockShape.length, function () { return "input rank " + $x.rank + " should be > than [blockShape] " + blockShape.length; }); + assert(paddings.length === blockShape.length, function () { return "paddings.shape[0] " + paddings.length + " must be equal to [blockShape] " + blockShape.length; }); + assert($x.shape.reduce(function (a, b, i) { + if (i > 0 && i <= blockShape.length) { + return a && + ((b + paddings[i - 1][0] + paddings[i - 1][1]) % + blockShape[i - 1] === + 0); + } + return a; + }, true), function () { return "input spatial dimensions " + $x.shape.slice(1) + " with paddings " + paddings.toString() + " must be divisible by blockShapes " + blockShape.toString(); }); + var grad = function (dy) { + return { $x: function () { return dy.batchToSpaceND(blockShape, paddings); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.spaceToBatchND($x, blockShape, paddings); }, { $x: $x }, grad); + } + /** + * Unstacks a `tf.Tensor` of rank-`R` into a list of rank-`(R-1)` `tf.Tensor`s. + * + * ```js + * const a = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * tf.unstack(a).forEach(tensor => tensor.print()); + * ``` + * + * @param x A tensor object. + * @param axis The axis to unstack along. Defaults to 0 (the first dim). + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function unstack_(x, axis) { + if (axis === void 0) { axis = 0; } + axis = axis || 0; + var $x = convertToTensor(x, 'x', 'unstack'); + assert(axis >= -$x.shape.length && axis < $x.shape.length, function () { + return "Axis = " + axis + " is not in [-" + $x.shape.length + ", " + $x.shape.length + ")"; + }); + if (axis < 0) { + axis += $x.shape.length; + } + var grad = function (dy) { + return { $x: function () { return stack(dy, axis); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.unstack($x, axis); }, { $x: $x }, grad); + } + /** + * Computes the cumulative sum of a `tf.Tensor` along `axis`. + * + * ```js + * const x = tf.tensor([1, 2, 3, 4]); + * x.cumsum().print(); + * ``` + * ```js + * const x = tf.tensor([[1, 2], [3, 4]]); + * x.cumsum().print(); + * ``` + * + * @param x The input tensor to be summed. + * @param axis The axis along which to sum. Optional. Defaults to 0. + * @param exclusive Whether to perform exclusive cumulative sum. Optional. + * Defaults to false. If set to true then the sum of each tensor entry + * does not include its own value, but only the values previous to it + * along the specified axis. + * @param reverse Whether to sum in the opposite direction. Optional. + * Defaults to false. + */ + /** @doc {heading: 'Operations', subheading: 'Scan'} */ + function cumsum_(x, axis, exclusive, reverse) { + if (axis === void 0) { axis = 0; } + if (exclusive === void 0) { exclusive = false; } + if (reverse === void 0) { reverse = false; } + var $x = convertToTensor(x, 'x', 'cumsum'); + axis = axis | 0; + var permutation = getAxesPermutation([axis], $x.rank); + var permutedX = $x; + if (permutation != null) { + permutedX = $x.transpose(permutation); + } + var permutedAxis = getInnerMostAxes(1, $x.rank)[0]; + var grad = function (dy) { + return { permutedX: function () { return dy.cumsum(axis, exclusive, !reverse); } }; + }; + var value = ENGINE.runKernelFunc(function (backend) { return backend.cumsum(permutedX, permutedAxis, exclusive, reverse); }, { permutedX: permutedX }, grad); + if (permutation != null) { + value = value.transpose(permutation); + } + return value; + } + /** + * Returns a `tf.Tensor` that has expanded rank, by inserting a dimension + * into the tensor's shape. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * const axis = 1; + * x.expandDims(axis).print(); + * ``` + * + * @param x The input tensor whose dimensions to be expanded. + * @param axis The dimension index at which to insert shape of `1`. Defaults + * to 0 (the first dimension). + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function expandDims_(x, axis) { + if (axis === void 0) { axis = 0; } + var parseAs = null; + var $x = convertToTensor(x, 'x', 'expandDims', parseAs); + assert(axis <= $x.rank, function () { return 'Axis must be <= rank of the tensor'; }); + var newShape = $x.shape.slice(); + if (axis < 0) { + // Negative value is counted from the tail of rank. + assert(-($x.rank + 1) <= axis, function () { return "Axis must be in the interval [" + -($x.rank + 1) + ", " + $x.rank + "]"; }); + axis = $x.rank + axis + 1; + } + newShape.splice(axis, 0, 1); + return reshape($x, newShape); + } + /** + * Rearranges data from depth into blocks of spatial data. More specifically, + * this op outputs a copy of the input tensor where values from the `depth` + * dimension are moved in spatial blocks to the `height` and `width` dimensions. + * The attr `blockSize` indicates the input block size and how the data is + * moved. + * + * - Chunks of data of size `blockSize * blockSize` from depth are rearranged + * into non-overlapping blocks of size `blockSize x blockSize` + * + * - The width the output tensor is `inputWidth * blockSize`, whereas the + * height is `inputHeight * blockSize` + * + * - The Y, X coordinates within each block of the output image are determined + * by the high order component of the input channel index + * + * - The depth of the input tensor must be divisible by `blockSize * + * blockSize` + * + * The `dataFormat` attr specifies the layout of the input and output tensors + * with the following options: "NHWC": [ `batch, height, width, channels` ] + * "NCHW": [ `batch, channels, height, width` ] + * + * ```js + * const x = tf.tensor4d([1, 2, 3, 4], [1, 1, 1, 4]); + * const blockSize = 2; + * const dataFormat = "NHWC"; + * + * tf.depthToSpace(x, blockSize, dataFormat).print(); + * ``` + * + * @param x The input tensor of rank 4 + * @param blockSIze An `int` that is `>= 2`. The size of the spatial block + * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to "NHWC" + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function depthToSpace_(x, blockSize, dataFormat) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + var $x = convertToTensor(x, 'x', 'depthToSpace'); + var inputHeight = (dataFormat === 'NHWC') ? $x.shape[1] : $x.shape[2]; + var inputWidth = (dataFormat === 'NHWC') ? $x.shape[2] : $x.shape[3]; + var inputDepth = (dataFormat === 'NHWC') ? $x.shape[3] : $x.shape[1]; + assert(inputHeight * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputHeight + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; }); + assert(inputWidth * blockSize >= 0, function () { return "Negative dimension size caused by overflow when multiplying\n " + inputWidth + " and " + blockSize + " for depthToSpace with input shape\n " + $x.shape; }); + assert((inputDepth % (blockSize * blockSize) === 0), function () { return "Dimension size must be evenly divisible by " + blockSize * blockSize + " but is " + inputDepth + " for depthToSpace with input shape " + $x.shape; }); + return ENGINE.runKernelFunc(function (backend) { return backend.depthToSpace($x, blockSize, dataFormat); }, { $x: $x }); + } + /** + * Computes the difference between two lists of numbers. + * + * Given a Tensor `x` and a Tensor `y`, this operation returns a Tensor `out` + * that represents all values that are in `x` but not in `y`. The returned + * Tensor `out` is sorted in the same order that the numbers appear in `x` + * (duplicates are preserved). This operation also returns a Tensor indices that + * represents the position of each out element in `x`. In other words: + * + * `out[i] = x[idx[i]] for i in [0, 1, ..., out.length - 1]` + * + * ```js + * const x = [1, 2, 3, 4, 5, 6]; + * const y = [1, 3, 5]; + * + * const [out, indices] = await tf.setdiff1dAsync(x, y); + * out.print(); // [2, 4, 6] + * indices.print(); // [1, 3, 5] + * ``` + * + * @param x 1-D Tensor. Values to keep. + * @param y 1-D Tensor. Must have the same type as x. Values to exclude in the + * output. + * @returns Promise of Tensor tuple [out, indices]. + * out: Tensor with the same type as x. + * indices: A Tensor of type int32. + */ + /** @doc {heading: 'Tensors', subheading: 'Transformations'} */ + function setdiff1dAsync_(x, y) { + return __awaiter(this, void 0, void 0, function () { + var $x, $y, xVals, yVals, ySet, outputSize, i, buffer, indices, i, p; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + $x = convertToTensor(x, 'x', 'setdiff1d'); + $y = convertToTensor(y, 'y', 'setdiff1d'); + assert($x.dtype === $y.dtype, function () { return "x and y should have the same dtype, but got x (" + $x.dtype + ") and y (" + $y.dtype + ")."; }); + assert($x.rank === 1, function () { return "x should be 1D tensor, but got x (" + $x.shape + ")."; }); + assert($y.rank === 1, function () { return "y should be 1D tensor, but got y (" + $y.shape + ")."; }); + return [4 /*yield*/, $x.data()]; + case 1: + xVals = _a.sent(); + return [4 /*yield*/, $y.data()]; + case 2: + yVals = _a.sent(); + ySet = new Set(yVals); + outputSize = 0; + for (i = 0; i < xVals.length; i++) { + if (!ySet.has(xVals[i])) { + outputSize++; + } + } + buffer = new TensorBuffer([outputSize], $x.dtype); + indices = new TensorBuffer([outputSize], 'int32'); + for (i = 0, p = 0; i < xVals.length; i++) { + if (!ySet.has(xVals[i])) { + buffer.values[p] = xVals[i]; + indices.values[p] = i; + p++; + } + } + return [2 /*return*/, [buffer.toTensor(), indices.toTensor()]]; + } + }); + }); + } + /** + * Creates an empty `tf.TensorBuffer` with the specified `shape` and `dtype`. + * + * The values are stored in CPU as `TypedArray`. Fill the buffer using + * `buffer.set()`, or by modifying directly `buffer.values`. + * + * When done, call `buffer.toTensor()` to get an immutable `tf.Tensor` with + * those values. + * + * ```js + * // Create a buffer and set values at particular indices. + * const buffer = tf.buffer([2, 2]); + * buffer.set(3, 0, 0); + * buffer.set(5, 1, 0); + * + * // Convert the buffer back to a tensor. + * buffer.toTensor().print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param dtype The dtype of the buffer. Defaults to 'float32'. + * @param values The values of the buffer as `TypedArray`. Defaults to + * zeros. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function buffer(shape, dtype, values) { + if (dtype === void 0) { dtype = 'float32'; } + dtype = dtype || 'float32'; + assertNonNegativeIntegerDimensions(shape); + return new TensorBuffer(shape, dtype, values); + } + /** + * Prints information about the `tf.Tensor` including its data. + * + * ```js + * const verbose = true; + * tf.tensor2d([1, 2, 3, 4], [2, 2]).print(verbose); + * ``` + * @param x The tensor to be printed. + * @param verbose Whether to print verbose information about the ` Tensor`, + * including dtype and size. + */ + /** @doc {heading: 'Tensors', subheading: 'Creation'} */ + function print(x, verbose) { + if (verbose === void 0) { verbose = false; } + console.log(x.toString(verbose)); + } + var batchToSpaceND = op({ batchToSpaceND_: batchToSpaceND_ }); + var cast = op({ cast_: cast_ }); + var clone = op({ clone_: clone_ }); + var cumsum = op({ cumsum_: cumsum_ }); + var depthToSpace = op({ depthToSpace_: depthToSpace_ }); + var expandDims = op({ expandDims_: expandDims_ }); + var eye = op({ eye_: eye_ }); + var multinomial = op({ multinomial_: multinomial_ }); + var oneHot = op({ oneHot_: oneHot_ }); + var pad = op({ pad_: pad_ }); + var pad1d = op({ pad1d_: pad1d_ }); + var pad2d = op({ pad2d_: pad2d_ }); + var pad3d = op({ pad3d_: pad3d_ }); + var pad4d = op({ pad4d_: pad4d_ }); + var rand = op({ rand_: rand_ }); + var randomNormal = op({ randomNormal_: randomNormal_ }); + var randomGamma = op({ randomGamma_: randomGamma_ }); + var randomUniform = op({ randomUniform_: randomUniform_ }); + var reshape = op({ reshape_: reshape_ }); + var spaceToBatchND = op({ spaceToBatchND_: spaceToBatchND_ }); + var squeeze = op({ squeeze_: squeeze_ }); + var stack = op({ stack_: stack_ }); + var tile = op({ tile_: tile_ }); + var truncatedNormal = op({ truncatedNormal_: truncatedNormal_ }); + var unstack = op({ unstack_: unstack_ }); + var setdiff1dAsync = setdiff1dAsync_; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Gets the new shape of the input Tensor after it's been reshaped + * to: + * [blockShape[0], ..., blockShape[M-1], batch / prod(blockShape), + * inputShape[1], ..., inputShape[N-1]] + * + * See step 1: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd + */ + function getReshaped(inputShape, blockShape, prod, batchToSpace) { + if (batchToSpace === void 0) { batchToSpace = true; } + var reshaped = []; + if (batchToSpace) { + reshaped = reshaped.concat(blockShape.slice(0)); + reshaped.push(inputShape[0] / prod); + reshaped = reshaped.concat(inputShape.slice(1)); + } + else { + reshaped = reshaped.concat(inputShape[0]); + var spatialLength = blockShape.length; + for (var i = 0; i < spatialLength; ++i) { + reshaped = + reshaped.concat([inputShape[i + 1] / blockShape[i], blockShape[i]]); + } + reshaped = reshaped.concat(inputShape.slice(spatialLength + 1)); + } + return reshaped; + } + /** + * Gets the permutation that will transpose the dimensions of the + * reshaped tensor to shape: + * + * [batch / prod(block_shape),inputShape[1], blockShape[0], ..., + * inputShape[M], blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]] + * + * see step 2: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd + */ + function getPermuted(reshapedRank, blockShapeRank, batchToSpace) { + if (batchToSpace === void 0) { batchToSpace = true; } + var permuted = []; + if (batchToSpace) { + permuted.push(blockShapeRank); + for (var i = blockShapeRank + 1; i < reshapedRank; ++i) { + if (i <= 2 * blockShapeRank) { + permuted.push(i); + permuted.push(i - (blockShapeRank + 1)); + } + else { + permuted.push(i); + } + } + } + else { + var permutedBeforeBatch = []; + var permutedAfterBatch = []; + for (var i = 1; i < reshapedRank; ++i) { + if (i >= blockShapeRank * 2 + 1 || i % 2 === 1) { + permutedAfterBatch.push(i); + } + else { + permutedBeforeBatch.push(i); + } + } + permuted.push.apply(permuted, permutedBeforeBatch); + permuted.push(0); + permuted.push.apply(permuted, permutedAfterBatch); + } + return permuted; + } + /** + * Gets the shape of the reshaped and permuted input Tensor before any cropping + * is applied. The new shape will be: + * + * [batch / prod(blockShape),inputShape[1] * blockShape[0], ..., + * inputShape[M] * blockShape[M-1],inputShape[M+1], ..., inputShape[N-1]] + * + * See step 3: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd + */ + function getReshapedPermuted(inputShape, blockShape, prod, batchToSpace) { + if (batchToSpace === void 0) { batchToSpace = true; } + var reshapedPermuted = []; + if (batchToSpace) { + reshapedPermuted.push(inputShape[0] / prod); + } + else { + reshapedPermuted.push(inputShape[0] * prod); + } + for (var i = 1; i < inputShape.length; ++i) { + if (i <= blockShape.length) { + if (batchToSpace) { + reshapedPermuted.push(blockShape[i - 1] * inputShape[i]); + } + else { + reshapedPermuted.push(inputShape[i] / blockShape[i - 1]); + } + } + else { + reshapedPermuted.push(inputShape[i]); + } + } + return reshapedPermuted; + } + /** + * Converts the crops argument into the beginning coordinates of a slice + * operation. + */ + function getSliceBeginCoords(crops, blockShape) { + var sliceBeginCoords = [0]; + for (var i = 0; i < blockShape; ++i) { + sliceBeginCoords.push(crops[i][0]); + } + return sliceBeginCoords; + } + /** + * Converts the crops argument into the size of a slice operation. When + * combined with getSliceBeginCoords this function allows the reshaped and + * permuted Tensor to be cropped to its final output shape of: + * + * inputShape[1] * blockShape[0] - crops[0,0] - crops[0,1], ..., + * inputShape[M] * blockShape[M-1] -crops[M-1,0] - + * crops[M-1,1],inputShape[M+1], ..., inputShape[N-1]] + * + * See step 4: https://www.tensorflow.org/api_docs/python/tf/batch_to_space_nd + */ + function getSliceSize(uncroppedShape, crops, blockShape) { + var sliceSize = uncroppedShape.slice(0, 1); + for (var i = 0; i < blockShape; ++i) { + sliceSize.push(uncroppedShape[i + 1] - crops[i][0] - crops[i][1]); + } + return sliceSize; + } + + /** + * Validate gather nd inputs. + * + * @param tensor The tensor contains the source values. + * @param indices The tensor contains the indices to slice the source. + * + * @returns [resultShape, numUpdates, sliceSize, strides] + */ + function prepareAndValidate(tensor, indices) { + if (tensor.rank < 1) { + throw new Error('tf.gatherND() expects the input to be rank 1 or higher,' + + (" but the rank was " + tensor.rank + ".")); + } + if (indices.rank < 1) { + throw new Error('tf.gatherND() expects the indices to be rank 1 or higher,' + + (" but the rank was " + indices.rank + ".")); + } + if (indices.dtype !== 'int32') { + throw new Error('tf.gatherND() expects the indices to be int32 type,' + + (" but the dtype was " + indices.dtype + ".")); + } + if (indices.shape[indices.rank - 1] > tensor.rank) { + throw new Error('index innermost dimension length must be <= tensor rank; saw: ' + + (indices.shape[indices.rank - 1] + " vs. " + tensor.rank)); + } + if (tensor.size === 0) { + throw new Error('Requested more than 0 entries, but input is empty.' + + (" Input shape: " + tensor.shape + ".")); + } + var indicesShape = indices.shape; + var sliceRank = indicesShape[indicesShape.length - 1]; + // The result shape is + // indices.shape[:-1] + params.shape[indices.shape[-1]:] + var nResult = 1; + for (var i = 0; i < indicesShape.length - 1; ++i) { + nResult *= indicesShape[i]; + } + var inputShape = tensor.shape; + var resultShape = indicesShape.slice(); + resultShape.pop(); + var sliceSize = 1; + for (var i = sliceRank; i < tensor.rank; ++i) { + sliceSize *= inputShape[i]; + resultShape.push(inputShape[i]); + } + var strides = computeStrides(tensor.shape).map(function (stride) { return stride / sliceSize; }).concat([1]).slice(0, sliceRank); + return [resultShape, nResult, sliceSize, strides]; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PARALLELIZE_THRESHOLD = 30; + function computeOptimalWindowSize(inSize) { + if (inSize <= PARALLELIZE_THRESHOLD) { + return inSize; + } + return nearestDivisor(inSize, Math.floor(Math.sqrt(inSize))); + } + + /** + * Check whether updates.shape = indices.shape[:batchDim] + + * shape[sliceDim:] + * + * @param x The input tensor. + */ + function validateUpdateShape(shape, indices, updates) { + var sliceDim = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1; + var batchDim = (indices.rank > 1) ? indices.rank - 1 : 1; + var shapeError = 'Must have updates.shape = indices.shape[:batchDim] + ' + + ("shape[sliceDim:], got updates.shape: " + updates.shape) + + (", indices.shape: " + indices.shape + ", shape: " + shape) + + (", sliceDim: " + sliceDim + ", and batchDim: " + batchDim + "."); + if (updates.rank < batchDim) { + throw new Error(shapeError + (" update.rank < " + batchDim + ". ")); + } + if (shape.length < sliceDim + (updates.rank - batchDim)) { + throw new Error(shapeError + + (" Output shape length < " + (sliceDim + (updates.rank - batchDim)))); + } + if (updates.rank !== batchDim + shape.length - sliceDim) { + throw new Error(shapeError + (" update.rank != " + (batchDim + shape.length - sliceDim))); + } + for (var d = 0; d < batchDim; ++d) { + if (updates.shape[d] !== indices.shape[d]) { + throw new Error(shapeError + + (" updates.shape[" + d + "] (" + updates.shape[d] + ") != indices.shape[" + d + "] (" + indices.shape[d] + ").")); + } + } + for (var d = 0; d < updates.rank - batchDim; ++d) { + if (updates.shape[d + batchDim] !== shape[d + sliceDim]) { + throw new Error(shapeError + + (" updates.shape[" + (d + batchDim) + "] (" + updates.shape[d + batchDim] + ") != shape[" + (d + batchDim) + "] (" + shape[d + batchDim] + ")")); + } + } + } + /** + * Validate scatter nd inputs. + * + * @param update The tensor contains the update values. + * @param indices The tensor contains the indices for the update values. + * @param shape The shape of the output tensor. + */ + function validateInput(updates, indices, shape) { + if (indices.rank < 1) { + throw new Error('tf.scatterND() expects the indices to be rank 1 or higher,' + + (" but the rank was " + indices.rank + ".")); + } + if (updates.rank < 1) { + throw new Error('tf.scatterND() expects the updates to be rank 1 or higher,' + + (" but the rank was " + updates.rank + ".")); + } + if (indices.dtype !== 'int32') { + throw new Error("The dtype of 'indices' should be int32, but got dtype: " + indices.dtype); + } + if (shape.length < 1) { + throw new Error("Output rank must be greater or equal to 1, but got shape: " + shape); + } + if (shape.length === 0) { + if (indices.size === 0) { + throw new Error("Indices specified for empty output. indices shape: " + indices.shape); + } + if (updates.size === 0) { + throw new Error("Updates specified for empty output. updates shape: " + updates.shape); + } + } + validateUpdateShape(shape, indices, updates); + } + /** + * Calculate the shape information for the output. + * + * @param update The tensor contains the update values. + * @param indices The tensor contains the indices for the update values. + * @param shape The shape of the output tensor. + * + * @returns ScatterShapeInfo + */ + function calculateShapes(updates, indices, shape) { + // Calculate the number of dimensions in indices + var sliceRank = (indices.rank > 1) ? indices.shape[indices.rank - 1] : 1; + // Calculate the number of elements that make up each slice of our updated + // tensor. This allows us to work with flattened tensors and copy over whole + // slices at a time. + var totalNd = shape.length; + var sliceSize = 1; + for (var i = sliceRank; i < totalNd; ++i) { + sliceSize *= shape[i]; + } + var safeSliceDim = (sliceRank < 1) ? 1 : sliceRank; + var numUpdates = indices.size / safeSliceDim; + var strides = computeStrides(shape.slice(0, sliceRank)).concat([1]); + var outputSize = sizeFromShape(shape); + return { sliceRank: sliceRank, numUpdates: numUpdates, sliceSize: sliceSize, strides: strides, outputSize: outputSize }; + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function segOpComputeOptimalWindowSize(inSize, numSegments) { + var done = false; + var res; + if (inSize <= PARALLELIZE_THRESHOLD) { + res = inSize; + done = true; + } + else { + res = nearestDivisor(inSize, Math.floor(Math.sqrt(inSize))); + } + while (!done) { + if (res > numSegments || res === inSize) { + done = true; + } + else { + res = nearestDivisor(inSize, res + 1); + } + } + return res; + } + function computeOutShape$1(aShape, axis, numSegments) { + var outShape = []; + var rank = aShape.length; + for (var dim = 0; dim < rank; dim++) { + if (dim !== axis) { + outShape.push(aShape[dim]); + } + else { + outShape.push(numSegments); + } + } + return outShape; + } + function collectGatherOpShapeInfo(x, indices, axis) { + var dimSize = x.shape[axis]; + var outputShape = []; + var batchSize = 1; + var sliceSize = 1; + for (var i = 0; i < axis; i++) { + outputShape.push(x.shape[i]); + batchSize *= x.shape[i]; + } + for (var i = 0; i < indices.rank; i++) { + outputShape.push(indices.shape[i]); + } + for (var i = axis + 1; i < x.rank; i++) { + outputShape.push(x.shape[i]); + sliceSize *= x.shape[i]; + } + return { batchSize: batchSize, sliceSize: sliceSize, dimSize: dimSize, outputShape: outputShape }; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function assertParamsValid(input, begin, size) { + assert(input.rank === begin.length, function () { return "Error in slice" + input.rank + "D: Length of begin " + begin + " must " + + ("match the rank of the array (" + input.rank + ")."); }); + assert(input.rank === size.length, function () { return "Error in slice" + input.rank + "D: Length of size " + size + " must " + + ("match the rank of the array (" + input.rank + ")."); }); + var _loop_1 = function (i) { + assert(begin[i] + size[i] <= input.shape[i], function () { return "Error in slice" + input.rank + "D: begin[" + i + "] + size[" + i + "] " + + ("(" + (begin[i] + size[i]) + ") would overflow input.shape[" + i + "] (" + input.shape[i] + ")"); }); + }; + for (var i = 0; i < input.rank; ++i) { + _loop_1(i); + } + } + /** Converts a binary mask to an array of axes. Used in stridedSlice(). */ + function maskToAxes(mask) { + var axes = []; + var axis = 0; + while (mask > 0) { + if (mask & 1) { + axes.push(axis); + } + mask /= 2; + axis++; + } + return axes; + } + /** Computes the output shape given the strided slice params. */ + function computeOutShape$2(begin, end, strides) { + var size = []; + for (var axis = 0; axis < begin.length; axis++) { + size[axis] = Math.ceil((end[axis] - begin[axis]) / strides[axis]); + } + return size; + } + function startForAxis(beginMask, startIndices, strides, inputShape, axis) { + // Begin with the specified index + var start = startIndices[axis]; + var stride = strides[axis] || 1; + // Check the axis bit from right of beginMask or the begin index is not set + // for the axis. + if (beginMask & 1 << axis || start == null) { + if (stride > 0) { + // Forward iteration - use the first element. These values will get + // clamped below (Note: We could have set them to 0 and axis_size-1, but + // use lowest() and max() to maintain symmetry with StopForAxis()) + start = Number.MIN_SAFE_INTEGER; + } + else { + // Backward iteration - use the last element. + start = Number.MAX_SAFE_INTEGER; + } + } + // Handle negative indices + var axisSize = inputShape[axis]; + if (start < 0) { + start += axisSize; + } + // Clamping + start = clamp(0, start, axisSize - 1); + return start; + } + function stopForAxis(endMask, stopIndices, strides, inputShape, axis) { + // Begin with the specified index + var stop = stopIndices[axis]; + var stride = strides[axis] || 1; + // Check the axis bit from right of endMask or if the stop index is not set + // for this axis. + if (endMask & (1 << axis) || stop == null) { + if (stride > 0) { + // Forward iteration - use the last element. These values will get + // clamped below + stop = Number.MAX_SAFE_INTEGER; + } + else { + // Backward iteration - use the first element. + stop = Number.MIN_SAFE_INTEGER; + } + } + // Handle negative indices + var axisSize = inputShape[axis]; + if (stop < 0) { + stop += axisSize; + } + // Clamping + // Because the end index points one past the last element, we need slightly + // different clamping ranges depending on the direction. + if (stride > 0) { + // Forward iteration + stop = clamp(0, stop, axisSize); + } + else { + // Backward iteration + stop = clamp(-1, stop, axisSize - 1); + } + return stop; + } + /** + * Returns true if the slice occupies a continous set of elements in the + * 'flat' space. + */ + function isSliceContinous(shape, begin, size) { + // Index of the first axis that has size > 1. + var firstNonOneAxis = size.length; + for (var i = 0; i < size.length; i++) { + if (size[i] > 1) { + firstNonOneAxis = i; + break; + } + } + for (var i = firstNonOneAxis + 1; i < size.length; i++) { + if (begin[i] > 0 || size[i] !== shape[i]) { + return false; + } + } + return true; + } + function computeFlatOffset(begin, strides) { + var flatOffset = begin.length > 0 ? begin[begin.length - 1] : 1; + for (var i = 0; i < begin.length - 1; i++) { + flatOffset += begin[i] * strides[i]; + } + return flatOffset; + } + + var slice_util = /*#__PURE__*/Object.freeze({ + assertParamsValid: assertParamsValid, + maskToAxes: maskToAxes, + computeOutShape: computeOutShape$2, + startForAxis: startForAxis, + stopForAxis: stopForAxis, + isSliceContinous: isSliceContinous, + computeFlatOffset: computeFlatOffset + }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Provided `f(x)`, returns another function `g(x, dy?)`, which gives the + * gradient of `f(x)` with respect to `x`. + * + * If `dy` is provided, the gradient of `f(x).mul(dy).sum()` with respect to + * `x` is computed instead. `f(x)` must take a single tensor `x` and return a + * single tensor `y`. If `f()` takes multiple inputs, use `tf.grads` instead. + * + * ```js + * // f(x) = x ^ 2 + * const f = x => x.square(); + * // f'(x) = 2x + * const g = tf.grad(f); + * + * const x = tf.tensor1d([2, 3]); + * g(x).print(); + * ``` + * + * ```js + * // f(x) = x ^ 3 + * const f = x => x.pow(tf.scalar(3, 'int32')); + * // f'(x) = 3x ^ 2 + * const g = tf.grad(f); + * // f''(x) = 6x + * const gg = tf.grad(g); + * + * const x = tf.tensor1d([2, 3]); + * gg(x).print(); + * ``` + * + * @param f The function f(x), to compute gradient for. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function grad(f) { + assert(isFunction(f), function () { return 'The f passed in grad(f) must be a function'; }); + return function (x, dy) { + // x can be of any dtype, thus null as the last argument. + var $x = convertToTensor(x, 'x', 'tf.grad', null); + var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grad') : null; + return ENGINE.tidy(function () { + var _a = ENGINE.gradients(function () { return f($x); }, [$x], $dy), value = _a.value, grads = _a.grads; + if ($dy != null) { + assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grad(f)(x, dy) must match the shape ' + + 'returned by f(x)'); + } + checkGrads(grads); + return grads[0]; + }); + }; + } + /** + * Provided `f(x1, x2,...)`, returns another function `g([x1, x2,...], dy?)`, + * which gives an array of gradients of `f()` with respect to each input + * [`x1`,`x2`,...]. + * + * If `dy` is passed when calling `g()`, the gradient of + * `f(x1,...).mul(dy).sum()` with respect to each input is computed instead. + * The provided `f` must take one or more tensors and return a single tensor + * `y`. If `f()` takes a single input, we recommend using `tf.grad` instead. + * + * ```js + * // f(a, b) = a * b + * const f = (a, b) => a.mul(b); + * // df / da = b, df / db = a + * const g = tf.grads(f); + * + * const a = tf.tensor1d([2, 3]); + * const b = tf.tensor1d([-2, -3]); + * const [da, db] = g([a, b]); + * console.log('da'); + * da.print(); + * console.log('db'); + * db.print(); + * ``` + * + * @param f The function `f(x1, x2,...)` to compute gradients for. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function grads(f) { + assert(isFunction(f), function () { return 'The f passed in grads(f) must be a function'; }); + return function (args, dy) { + assert(Array.isArray(args), function () { return 'The args passed in grads(f)(args) must be an array ' + + 'of `Tensor`s or `TensorLike`s'; }); + // args can be of any dtype, thus null as the last argument. + var $args = convertToTensorArray(args, 'args', 'tf.grads', null); + var $dy = (dy != null) ? convertToTensor(dy, 'dy', 'tf.grads') : null; + return ENGINE.tidy(function () { + var _a = ENGINE.gradients(function () { return f.apply(void 0, $args); }, $args, $dy), value = _a.value, grads = _a.grads; + if ($dy != null) { + assertShapesMatch(value.shape, $dy.shape, 'The shape of dy passed in grads(f)([x1,...], dy) must ' + + 'match the shape returned by f([x1,...])'); + } + checkGrads(grads); + return grads; + }); + }; + } + /** + * Like `tf.grad`, but also returns the value of `f()`. Useful when `f()` + * returns a metric you want to show. + * + * The result is a rich object with the following properties: + * - grad: The gradient of `f(x)` w.r.t `x` (result of `tf.grad`). + * - value: The value returned by `f(x)`. + * + * ```js + * // f(x) = x ^ 2 + * const f = x => x.square(); + * // f'(x) = 2x + * const g = tf.valueAndGrad(f); + * + * const x = tf.tensor1d([2, 3]); + * const {value, grad} = g(x); + * + * console.log('value'); + * value.print(); + * console.log('grad'); + * grad.print(); + * ``` + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function valueAndGrad(f) { + assert(isFunction(f), function () { return 'The f passed in valueAndGrad(f) must be a function'; }); + return function (x, dy) { + assert(x instanceof Tensor, function () { return 'The x passed in valueAndGrad(f)(x) must be a tensor'; }); + assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrad(f)(x, dy) must be a tensor'; }); + var _a = ENGINE.gradients(function () { return f(x); }, [x], dy), grads = _a.grads, value = _a.value; + checkGrads(grads); + return { grad: grads[0], value: value }; + }; + } + /** + * Like `tf.grads`, but returns also the value of `f()`. Useful when `f()` + * returns a metric you want to show. + * + * The result is a rich object with the following properties: + * - grads: The gradients of `f()` w.r.t each input (result of `tf.grads`). + * - value: The value returned by `f(x)`. + * + * ```js + * // f(a, b) = a * b + * const f = (a, b) => a.mul(b); + * // df/da = b, df/db = a + * const g = tf.valueAndGrads(f); + * + * const a = tf.tensor1d([2, 3]); + * const b = tf.tensor1d([-2, -3]); + * const {value, grads} = g([a, b]); + * + * const [da, db] = grads; + * + * console.log('value'); + * value.print(); + * + * console.log('da'); + * da.print(); + * console.log('db'); + * db.print(); + * ``` + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function valueAndGrads(f) { + assert(isFunction(f), function () { return 'The f passed in valueAndGrads(f) must be a function'; }); + return function (args, dy) { + assert(Array.isArray(args) && args.every(function (arg) { return arg instanceof Tensor; }), function () { return 'The args passed in valueAndGrads(f)(args) must be array of ' + + 'tensors'; }); + assert(dy == null || dy instanceof Tensor, function () { return 'The dy passed in valueAndGrads(f)(args, dy) must be a tensor'; }); + var res = ENGINE.gradients(function () { return f.apply(void 0, args); }, args, dy); + if (dy != null) { + assertShapesMatch(res.value.shape, dy.shape, 'The shape of dy passed in valueAndGrads(f)([x1,...], dy) must ' + + 'match the shape returned by f([x1,...])'); + } + checkGrads(res.grads); + return res; + }; + } + /** + * Computes and returns the gradient of f(x) with respect to the list of + * trainable variables provided by `varList`. If no list is provided, it + * defaults to all trainable variables. + * + * ```js + * const a = tf.variable(tf.tensor1d([3, 4])); + * const b = tf.variable(tf.tensor1d([5, 6])); + * const x = tf.tensor1d([1, 2]); + * + * // f(a, b) = a * x ^ 2 + b * x + * const f = () => a.mul(x.square()).add(b.mul(x)).sum(); + * // df/da = x ^ 2, df/db = x + * const {value, grads} = tf.variableGrads(f); + * + * Object.keys(grads).forEach(varName => grads[varName].print()); + * ``` + * + * @param f The function to execute. f() should return a scalar. + * @param varList The list of variables to compute the gradients with respect + * to. Defaults to all trainable variables. + * @returns An object with the following keys and values: + * - `value`: The value of the function `f`. + * - `grads`: A map from the names of the variables to the gradients. + * If the `varList` argument is provided explicitly and contains a subset of + * non-trainable variables, this map in the return value will contain keys + * that map the names of the non-trainable variables to `null`. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function variableGrads(f, varList) { + assert(isFunction(f), function () { return 'The f passed in variableGrads(f) must be a function'; }); + assert(varList == null || + Array.isArray(varList) && varList.every(function (v) { return v instanceof Variable; }), function () { + return 'The varList passed in variableGrads(f, varList) must be an array ' + + 'of variables'; + }); + var specifiedVarList = varList != null; + if (!specifiedVarList) { + // Get all of the trainable variables. + varList = []; + for (var varName in ENGINE.registeredVariables) { + varList.push(ENGINE.registeredVariables[varName]); + } + } + var specifiedNonTrainable = specifiedVarList ? varList.filter(function (variable) { return !variable.trainable; }) : null; + // Prune non-trainable variables. + var originalVarCount = varList.length; + varList = varList.filter(function (variable) { return variable.trainable; }); + assert(varList.length > 0, function () { return "variableGrads() expects at least one of the input variables to " + + ("be trainable, but none of the " + originalVarCount + " variables is ") + + "trainable."; }); + var allowNoGradients = true; + var _a = ENGINE.gradients(f, varList, null, allowNoGradients), value = _a.value, grads = _a.grads; + assert(grads.some(function (g) { return g != null; }), function () { return 'Cannot find a connection between any variable and the result of ' + + 'the loss function y=f(x). Please make sure the operations that ' + + 'use variables are inside the function f passed to minimize().'; }); + assert(value.rank === 0, function () { return "The f passed in variableGrads(f) must return a scalar, but it " + + ("returned a rank-" + value.rank + " tensor"); }); + var namedGrads = {}; + varList.forEach(function (v, i) { + if (grads[i] != null) { + namedGrads[v.name] = grads[i]; + } + }); + if (specifiedNonTrainable != null) { + // If varList is explicitly provided and contains non-trainable values, + // add them to the returned gradients with `null` values. + specifiedNonTrainable.forEach(function (v) { return namedGrads[v.name] = null; }); + } + return { value: value, grads: namedGrads }; + } + /** + * Overrides the gradient computation of a function `f`. + * + * Takes a function + * `f(...inputs, save) => {value: Tensor, gradFunc: (dy, saved) => Tensor[]}` + * and returns another function `g(...inputs)` which takes the same inputs as + * `f`. When called, `g` returns `f().value`. In backward mode, custom gradients + * with respect to each input of `f` are computed using `f().gradFunc`. + * + * The `save` function passsed to `f` should be used for saving tensors needed + * in the gradient. And the `saved` passed to the `gradFunc` is a + * `NamedTensorMap`, which contains those saved tensor. + * + * ```js + * const customOp = tf.customGrad((x, save) => { + * // Save x to make sure it's available later for the gradient. + * save([x]); + * // Override gradient of our custom x ^ 2 op to be dy * abs(x); + * return { + * value: x.square(), + * // Note `saved.x` which points to the `x` we saved earlier. + * gradFunc: (dy, saved) => [dy.mul(saved[0].abs())] + * }; + * }); + * + * const x = tf.tensor1d([-1, -2, 3]); + * const dx = tf.grad(x => customOp(x)); + * + * console.log(`f(x):`); + * customOp(x).print(); + * console.log(`f'(x):`); + * dx(x).print(); + * ``` + * + * @param f The function to evaluate in forward mode, which should return + * `{value: Tensor, gradFunc: (dy, saved) => Tensor[]}`, where `gradFunc` + * returns the custom gradients of `f` with respect to its inputs. + */ + /** @doc {heading: 'Training', subheading: 'Gradients'} */ + function customGrad(f) { + return ENGINE.customGrad(f); + } + function checkGrads(grads) { + var numNullGradients = grads.filter(function (g) { return g == null; }).length; + if (numNullGradients > 0) { + throw new Error("Cannot compute gradient of y=f(x) with respect to x. Make sure that\n the f you passed encloses all operations that lead from x to y."); + } + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the softmax normalized vector given the logits. + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * + * a.softmax().print(); // or tf.softmax(a) + * ``` + * + * ```js + * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]); + * + * a.softmax().print(); // or tf.softmax(a) + * ``` + * + * @param logits The logits array. + * @param dim The dimension softmax would be performed on. Defaults to `-1` + * which indicates the last dimension. + */ + /** @doc {heading: 'Operations', subheading: 'Normalization'} */ + function softmax_(logits, dim) { + if (dim === void 0) { dim = -1; } + var $logits = convertToTensor(logits, 'logits', 'softmax'); + if (dim === -1) { + dim = $logits.rank - 1; + } + if (dim !== $logits.rank - 1) { + throw Error('Softmax along a non-last dimension is not yet supported. ' + + ("Logits was rank " + $logits.rank + " and dim was " + dim)); + } + var customOp = customGrad(function (logits, save) { + // Do it in log space for numerical stability. + // exp(X - logSumExp(X)) + var keepDims = true; + var lse = logits.logSumExp([dim], keepDims); + var logResult = logits.toFloat().sub(lse); + var y = logResult.exp(); + save([y]); + var gradFunc = function (dy, saved) { + var y = saved[0]; + var dyTimesY = dy.mul(y); + var keepDims = true; + return dyTimesY.sub(dyTimesY.sum([dim], keepDims).mul(y)); + }; + return { value: y, gradFunc: gradFunc }; + }); + return customOp($logits); + } + /** + * Computes the log softmax. + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * + * a.logSoftmax().print(); // or tf.logSoftmax(a) + * ``` + * + * ```js + * const a = tf.tensor2d([2, 4, 6, 1, 2, 3], [2, 3]); + * + * a.logSoftmax().print(); // or tf.logSoftmax(a) + * ``` + * + * @param logits The logits array. + * @param axis The dimension softmax would be performed on. Defaults to `-1` + * which indicates the last dimension. + */ + /** @doc {heading: 'Operations', subheading: 'Normalization'} */ + function logSoftmax_(logits, axis) { + if (axis === void 0) { axis = -1; } + var $logits = convertToTensor(logits, 'logits', 'logSoftmax'); + if (axis === -1) { + axis = $logits.rank - 1; + } + if (axis !== $logits.rank - 1) { + throw Error('Log Softmax along a non-last dimension is not yet supported. ' + + ("Logits was rank " + $logits.rank + " and axis was " + axis)); + } + var customOp = customGrad(function (logits, save) { + var keepDims = true; + var xMax = logits.max(axis, true); + var shifted = logits.sub(xMax); + var value = shifted.toFloat().sub(shifted.exp().sum(axis, keepDims).log()); + save([value]); + var gradFunc = function (dy, saved) { + var value = saved[0]; + var softmax = value.exp(); + return dy.sub(dy.sum(axis, keepDims).mul(softmax)); + }; + return { value: value, gradFunc: gradFunc }; + }); + return customOp($logits); + } + var softmax = op({ softmax_: softmax_ }); + var logSoftmax = op({ logSoftmax_: logSoftmax_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var EPSILON_FLOAT32 = 1e-7; + var EPSILON_FLOAT16 = 1e-4; + /** Convenient class for storing tensor-related data. */ + var DataStorage = /** @class */ (function () { + function DataStorage(backend, dataMover) { + this.backend = backend; + this.dataMover = dataMover; + this.data = new WeakMap(); + this.dataIdsCount = 0; + } + DataStorage.prototype.get = function (dataId) { + if (!this.data.has(dataId)) { + this.dataMover.moveData(this.backend, dataId); + } + return this.data.get(dataId); + }; + DataStorage.prototype.set = function (dataId, value) { + this.dataIdsCount++; + this.data.set(dataId, value); + }; + DataStorage.prototype.has = function (dataId) { + return this.data.has(dataId); + }; + DataStorage.prototype.delete = function (dataId) { + this.dataIdsCount--; + return this.data.delete(dataId); + }; + DataStorage.prototype.numDataIds = function () { + return this.dataIdsCount; + }; + return DataStorage; + }()); + /** + * The interface that defines the kernels that should be implemented when + * adding a new backend. New backends don't need to implement every one of the + * methods, this can be done gradually (throw an error for unimplemented + * methods). + */ + var KernelBackend = /** @class */ (function () { + function KernelBackend() { + } + KernelBackend.prototype.time = function (f) { + return notYetImplemented('time'); + }; + KernelBackend.prototype.read = function (dataId) { + return notYetImplemented('read'); + }; + KernelBackend.prototype.readSync = function (dataId) { + return notYetImplemented('readSync'); + }; + KernelBackend.prototype.numDataIds = function () { + return notYetImplemented('numDataIds'); + }; + KernelBackend.prototype.disposeData = function (dataId) { + return notYetImplemented('disposeData'); + }; + KernelBackend.prototype.fromPixels = function (pixels, numChannels) { + return notYetImplemented('fromPixels'); + }; + KernelBackend.prototype.write = function (values, shape, dtype) { + return notYetImplemented('write'); + }; + KernelBackend.prototype.move = function (dataId, values, shape, dtype) { + return notYetImplemented('move'); + }; + KernelBackend.prototype.memory = function () { + return notYetImplemented('memory'); + }; + /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ + KernelBackend.prototype.floatPrecision = function () { + return notYetImplemented('floatPrecision'); + }; + /** Returns the smallest representable number. */ + KernelBackend.prototype.epsilon = function () { + return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; + }; + KernelBackend.prototype.batchMatMul = function (a, b, transposeA, transposeB) { + return notYetImplemented('batchMatMul'); + }; + KernelBackend.prototype.fusedBatchMatMul = function (_a) { + var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + return notYetImplemented('fusedBatchMatMul'); + }; + KernelBackend.prototype.slice = function (x, begin, size) { + return notYetImplemented('slice'); + }; + KernelBackend.prototype.stridedSlice = function (x, begin, end, strides) { + return notYetImplemented('stridedSlice'); + }; + KernelBackend.prototype.unstack = function (x, axis) { + return notYetImplemented('unstack'); + }; + KernelBackend.prototype.reverse = function (a, axis) { + return notYetImplemented('reverse'); + }; + KernelBackend.prototype.concat = function (tensors, axis) { + return notYetImplemented('concat'); + }; + KernelBackend.prototype.neg = function (a) { + return notYetImplemented('neg'); + }; + KernelBackend.prototype.add = function (a, b) { + return notYetImplemented('add'); + }; + KernelBackend.prototype.addN = function (tensors) { + return notYetImplemented('addN'); + }; + KernelBackend.prototype.subtract = function (a, b) { + return notYetImplemented('subtract'); + }; + KernelBackend.prototype.multiply = function (a, b) { + return notYetImplemented('multiply'); + }; + KernelBackend.prototype.realDivide = function (a, b) { + return notYetImplemented('realDivide'); + }; + KernelBackend.prototype.floorDiv = function (a, b) { + return notYetImplemented('floorDiv'); + }; + KernelBackend.prototype.sum = function (x, axes) { + return notYetImplemented('sum'); + }; + KernelBackend.prototype.prod = function (x, axes) { + return notYetImplemented('prod'); + }; + KernelBackend.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) { + return notYetImplemented('unsortedSegmentSum'); + }; + KernelBackend.prototype.argMin = function (x, axis) { + return notYetImplemented('argMin'); + }; + KernelBackend.prototype.argMax = function (x, axis) { + return notYetImplemented('argMax'); + }; + KernelBackend.prototype.equal = function (a, b) { + return notYetImplemented('equal'); + }; + KernelBackend.prototype.notEqual = function (a, b) { + return notYetImplemented('notEqual'); + }; + KernelBackend.prototype.less = function (a, b) { + return notYetImplemented('less'); + }; + KernelBackend.prototype.lessEqual = function (a, b) { + return notYetImplemented('lessEqual'); + }; + KernelBackend.prototype.greater = function (a, b) { + return notYetImplemented('greater'); + }; + KernelBackend.prototype.greaterEqual = function (a, b) { + return notYetImplemented('greaterEqual'); + }; + KernelBackend.prototype.logicalNot = function (a) { + return notYetImplemented('logicalNot'); + }; + KernelBackend.prototype.logicalAnd = function (a, b) { + return notYetImplemented('logicalAnd'); + }; + KernelBackend.prototype.logicalOr = function (a, b) { + return notYetImplemented('logicalOr'); + }; + KernelBackend.prototype.where = function (condition) { + return notYetImplemented('where'); + }; + KernelBackend.prototype.select = function (condition, a, b) { + return notYetImplemented('select'); + }; + KernelBackend.prototype.topk = function (x, k, sorted) { + return notYetImplemented('topk'); + }; + KernelBackend.prototype.min = function (x, axes) { + return notYetImplemented('min'); + }; + KernelBackend.prototype.minimum = function (a, b) { + return notYetImplemented('minimum'); + }; + KernelBackend.prototype.mod = function (a, b) { + return notYetImplemented('mod'); + }; + KernelBackend.prototype.max = function (x, axes) { + return notYetImplemented('max'); + }; + KernelBackend.prototype.maximum = function (a, b) { + return notYetImplemented('maximum'); + }; + KernelBackend.prototype.all = function (x, axes) { + return notYetImplemented('all'); + }; + KernelBackend.prototype.any = function (x, axes) { + return notYetImplemented('any'); + }; + KernelBackend.prototype.squaredDifference = function (a, b) { + return notYetImplemented('squaredDifference'); + }; + KernelBackend.prototype.ceil = function (x) { + return notYetImplemented('ceil'); + }; + KernelBackend.prototype.floor = function (x) { + return notYetImplemented('floor'); + }; + KernelBackend.prototype.round = function (x) { + return notYetImplemented('round'); + }; + KernelBackend.prototype.sign = function (x) { + return notYetImplemented('sign'); + }; + KernelBackend.prototype.isNaN = function (x) { + return notYetImplemented('isNaN'); + }; + KernelBackend.prototype.isInf = function (x) { + return notYetImplemented('isInf'); + }; + KernelBackend.prototype.isFinite = function (x) { + return notYetImplemented('isFinite'); + }; + KernelBackend.prototype.pow = function (a, b) { + return notYetImplemented('pow'); + }; + KernelBackend.prototype.exp = function (x) { + return notYetImplemented('exp'); + }; + KernelBackend.prototype.expm1 = function (x) { + return notYetImplemented('expm1'); + }; + KernelBackend.prototype.log = function (x) { + return notYetImplemented('log'); + }; + KernelBackend.prototype.log1p = function (x) { + return notYetImplemented('log1p'); + }; + KernelBackend.prototype.sqrt = function (x) { + return notYetImplemented('sqrt'); + }; + KernelBackend.prototype.rsqrt = function (x) { + return notYetImplemented('rsqrt'); + }; + KernelBackend.prototype.square = function (x) { + return notYetImplemented('square'); + }; + KernelBackend.prototype.reciprocal = function (x) { + return notYetImplemented('reciprocal'); + }; + KernelBackend.prototype.relu = function (x) { + return notYetImplemented('relu'); + }; + KernelBackend.prototype.relu6 = function (x) { + return notYetImplemented('relu6'); + }; + KernelBackend.prototype.prelu = function (x, a) { + return notYetImplemented('prelu'); + }; + KernelBackend.prototype.elu = function (x) { + return notYetImplemented('elu'); + }; + KernelBackend.prototype.eluDer = function (dy, y) { + return notYetImplemented('eluDer'); + }; + KernelBackend.prototype.selu = function (x) { + return notYetImplemented('selu'); + }; + KernelBackend.prototype.int = function (x) { + return notYetImplemented('int'); + }; + KernelBackend.prototype.clip = function (x, min, max) { + return notYetImplemented('clip'); + }; + KernelBackend.prototype.abs = function (x) { + return notYetImplemented('abs'); + }; + KernelBackend.prototype.complexAbs = function (x) { + return notYetImplemented('complexAbs'); + }; + KernelBackend.prototype.sigmoid = function (x) { + return notYetImplemented('sigmoid'); + }; + KernelBackend.prototype.softplus = function (x) { + return notYetImplemented('softplus'); + }; + KernelBackend.prototype.sin = function (x) { + return notYetImplemented('sin'); + }; + KernelBackend.prototype.cos = function (x) { + return notYetImplemented('cos'); + }; + KernelBackend.prototype.tan = function (x) { + return notYetImplemented('tan'); + }; + KernelBackend.prototype.asin = function (x) { + return notYetImplemented('asin'); + }; + KernelBackend.prototype.acos = function (x) { + return notYetImplemented('acos'); + }; + KernelBackend.prototype.atan = function (x) { + return notYetImplemented('atan'); + }; + KernelBackend.prototype.atan2 = function (a, b) { + return notYetImplemented('atan2'); + }; + KernelBackend.prototype.sinh = function (x) { + return notYetImplemented('sinh'); + }; + KernelBackend.prototype.cosh = function (x) { + return notYetImplemented('cosh'); + }; + KernelBackend.prototype.tanh = function (x) { + return notYetImplemented('tanh'); + }; + KernelBackend.prototype.asinh = function (x) { + return notYetImplemented('asinh'); + }; + KernelBackend.prototype.acosh = function (x) { + return notYetImplemented('acosh'); + }; + KernelBackend.prototype.atanh = function (x) { + return notYetImplemented('atanh'); + }; + KernelBackend.prototype.erf = function (x) { + return notYetImplemented('erf'); + }; + KernelBackend.prototype.step = function (x, alpha) { + return notYetImplemented('step'); + }; + KernelBackend.prototype.fusedConv2d = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + return notYetImplemented('fusedConv2d'); + }; + KernelBackend.prototype.conv2d = function (x, filter, convInfo) { + return notYetImplemented('conv2d'); + }; + KernelBackend.prototype.conv2dDerInput = function (dy, filter, convInfo) { + return notYetImplemented('conv2dDerInput'); + }; + KernelBackend.prototype.conv2dDerFilter = function (x, dY, convInfo) { + return notYetImplemented('conv2dDerFilter'); + }; + KernelBackend.prototype.fusedDepthwiseConv2D = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + return notYetImplemented('fusedDepthwiseConv2D'); + }; + KernelBackend.prototype.depthwiseConv2D = function (input, filter, convInfo) { + return notYetImplemented('depthwiseConv2D'); + }; + KernelBackend.prototype.depthwiseConv2DDerInput = function (dy, filter, convInfo) { + return notYetImplemented('depthwiseConv2DDerInput'); + }; + KernelBackend.prototype.depthwiseConv2DDerFilter = function (x, dY, convInfo) { + return notYetImplemented('depthwiseConv2DDerFilter'); + }; + KernelBackend.prototype.conv3d = function (x, filter, convInfo) { + return notYetImplemented('conv3d'); + }; + KernelBackend.prototype.conv3dDerInput = function (dy, filter, convInfo) { + return notYetImplemented('conv3dDerInput'); + }; + KernelBackend.prototype.conv3dDerFilter = function (x, dY, convInfo) { + return notYetImplemented('conv3dDerFilter'); + }; + KernelBackend.prototype.maxPool = function (x, convInfo) { + return notYetImplemented('maxPool'); + }; + KernelBackend.prototype.maxPoolBackprop = function (dy, x, y, convInfo) { + return notYetImplemented('maxPoolBackprop'); + }; + KernelBackend.prototype.avgPool = function (x, convInfo) { + return notYetImplemented('avgPool'); + }; + KernelBackend.prototype.avgPoolBackprop = function (dy, x, convInfo) { + return notYetImplemented('avgPoolBackprop'); + }; + KernelBackend.prototype.avgPool3d = function (x, convInfo) { + return notYetImplemented('avgPool3d'); + }; + KernelBackend.prototype.avgPool3dBackprop = function (dy, x, convInfo) { + return notYetImplemented('avgPool3dBackprop'); + }; + KernelBackend.prototype.maxPool3d = function (x, convInfo) { + return notYetImplemented('maxPool3d'); + }; + KernelBackend.prototype.maxPool3dBackprop = function (dy, x, y, convInfo) { + return notYetImplemented('maxPool3dBackprop'); + }; + KernelBackend.prototype.reshape = function (x, shape) { + return notYetImplemented('reshape'); + }; + KernelBackend.prototype.cast = function (x, dtype) { + return notYetImplemented('cast'); + }; + KernelBackend.prototype.tile = function (x, reps) { + return notYetImplemented('tile'); + }; + KernelBackend.prototype.pad = function (x, paddings, constantValue) { + return notYetImplemented('pad'); + }; + KernelBackend.prototype.transpose = function (x, perm) { + return notYetImplemented('transpose'); + }; + KernelBackend.prototype.gather = function (x, indices, axis) { + return notYetImplemented('gather'); + }; + KernelBackend.prototype.gatherND = function (x, indices) { + return notYetImplemented('gatherND'); + }; + KernelBackend.prototype.scatterND = function (indices, updates, shape) { + return notYetImplemented('scatterND'); + }; + KernelBackend.prototype.batchToSpaceND = function (x, blockShape, crops) { + return notYetImplemented('batchToSpaceND'); + }; + KernelBackend.prototype.spaceToBatchND = function (x, blockShape, paddings) { + return notYetImplemented('spaceToBatchND'); + }; + KernelBackend.prototype.resizeBilinear = function (x, newHeight, newWidth, alignCorners) { + return notYetImplemented('resizeBilinear'); + }; + KernelBackend.prototype.resizeBilinearBackprop = function (dy, x, alignCorners) { + return notYetImplemented('resizeBilinearBackprop'); + }; + KernelBackend.prototype.resizeNearestNeighbor = function (x, newHEight, newWidth, alignCorners) { + return notYetImplemented('resizeNearestNeighbor'); + }; + KernelBackend.prototype.resizeNearestNeighborBackprop = function (dy, x, alignCorners) { + return notYetImplemented('resizeNearestNeighborBackprop'); + }; + KernelBackend.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) { + return notYetImplemented('batchNormalization'); + }; + KernelBackend.prototype.localResponseNormalization4D = function (x, radius, bias, alpha, beta) { + return notYetImplemented('localResponseNormalization4D'); + }; + KernelBackend.prototype.LRNGrad = function (dy, inputImage, outputImage, radius, bias, alpha, beta) { + return notYetImplemented('LRNGrad'); + }; + KernelBackend.prototype.multinomial = function (logits, normalized, numSamples, seed) { + return notYetImplemented('multinomial'); + }; + KernelBackend.prototype.oneHot = function (indices, depth, onValue, offValue) { + return notYetImplemented('oneHot'); + }; + KernelBackend.prototype.cumsum = function (x, axis, exclusive, reverse) { + return notYetImplemented('cumsum'); + }; + KernelBackend.prototype.nonMaxSuppression = function (boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + return notYetImplemented('nonMaxSuppression'); + }; + KernelBackend.prototype.fft = function (x) { + return notYetImplemented('fft'); + }; + KernelBackend.prototype.ifft = function (x) { + return notYetImplemented('ifft'); + }; + KernelBackend.prototype.complex = function (real, imag) { + return notYetImplemented('complex'); + }; + KernelBackend.prototype.real = function (input) { + return notYetImplemented('real'); + }; + KernelBackend.prototype.imag = function (input) { + return notYetImplemented('imag'); + }; + KernelBackend.prototype.cropAndResize = function (image, boxes, boxIndex, cropSize, method, extrapolationValue) { + return notYetImplemented('cropAndResize'); + }; + KernelBackend.prototype.depthToSpace = function (x, blockSize, dataFormat) { + return notYetImplemented('depthToSpace'); + }; + // Aligns with the "SplitV" kernel in TensorFlow. + KernelBackend.prototype.split = function (value, sizeSplits, axis) { + return notYetImplemented('split'); + }; + KernelBackend.prototype.sparseToDense = function (sparseIndices, sparseValues, outputShape, defaultValue) { + return notYetImplemented('sparseToDense'); + }; + KernelBackend.prototype.diag = function (x) { + return notYetImplemented('diag'); + }; + KernelBackend.prototype.fill = function (shape, value, dtype) { + return notYetImplemented('fill'); + }; + KernelBackend.prototype.onesLike = function (x) { + return notYetImplemented('onesLike'); + }; + KernelBackend.prototype.zerosLike = function (x) { + return notYetImplemented('zerosLike'); + }; + KernelBackend.prototype.linspace = function (start, stop, num) { + return notYetImplemented('linspace'); + }; + KernelBackend.prototype.dispose = function () { + return notYetImplemented('dispose'); + }; + return KernelBackend; + }()); + function notYetImplemented(kernelName) { + throw new Error("'" + kernelName + "' not yet implemented or not found in the registry. " + + "Did you forget to import the kernel?"); + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Returns the dimensions in the input shape that are broadcasted to + * produce the provided output shape. + * + * The returned dimensions are 0-indexed and sorted. An example: + * inShape = [4, 1, 3] + * outShape = [5, 4, 3, 3] + * result = [1]. Dimension 1 (2nd dimension of input) gets broadcasted 1 => 3. + */ + function getBroadcastDims(inShape, outShape) { + var inRank = inShape.length; + var dims = []; + for (var i = 0; i < inRank; i++) { + var dim = inRank - 1 - i; + var a = inShape[dim] || 1; + var b = outShape[outShape.length - 1 - i] || 1; + if (b > 1 && a === 1) { + dims.unshift(dim); + } + } + return dims; + } + /** + * Returns the axes in the output space that should be reduced to produce + * the input space. + */ + function getReductionAxes(inShape, outShape) { + var result = []; + for (var i = 0; i < outShape.length; i++) { + var inDim = inShape[inShape.length - i - 1]; + var outAxis = outShape.length - i - 1; + var outDim = outShape[outAxis]; + if (inDim == null || (inDim === 1 && outDim > 1)) { + result.unshift(outAxis); + } + } + return result; + } + function assertAndGetBroadcastShape(shapeA, shapeB) { + var result = []; + var l = Math.max(shapeA.length, shapeB.length); + for (var i = 0; i < l; i++) { + var a = shapeA[shapeA.length - i - 1]; + if (a == null) { + a = 1; + } + var b = shapeB[shapeB.length - i - 1]; + if (b == null) { + b = 1; + } + if (a === 1) { + result.unshift(b); + } + else if (b === 1) { + result.unshift(a); + } + else if (a !== b) { + var errMsg = "Operands could not be broadcast together with shapes " + + (shapeA + " and " + shapeB + "."); + throw Error(errMsg); + } + else { + result.unshift(a); + } + } + return result; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function computePool2DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) { + if (dataFormat === void 0) { dataFormat = 'channelsLast'; } + var _a = parseTupleParam(filterSize), filterHeight = _a[0], filterWidth = _a[1]; + var filterShape; + if (dataFormat === 'channelsLast') { + filterShape = [filterHeight, filterWidth, inShape[3], inShape[3]]; + } + else if (dataFormat === 'channelsFirst') { + filterShape = [filterHeight, filterWidth, inShape[1], inShape[1]]; + } + else { + throw new Error("Unknown dataFormat " + dataFormat); + } + return computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, false, dataFormat); + } + /** + * Computes the information for a forward pass of a pooling3D operation. + */ + function computePool3DInfo(inShape, filterSize, strides, dilations, pad, roundingMode, dataFormat) { + if (dataFormat === void 0) { dataFormat = 'NDHWC'; } + var _a = parse3TupleParam(filterSize), filterDepth = _a[0], filterHeight = _a[1], filterWidth = _a[2]; + var filterShape; + var $dataFormat; + if (dataFormat === 'NDHWC') { + $dataFormat = 'channelsLast'; + filterShape = + [filterDepth, filterHeight, filterWidth, inShape[4], inShape[4]]; + } + else if (dataFormat === 'NCDHW') { + $dataFormat = 'channelsFirst'; + filterShape = + [filterDepth, filterHeight, filterWidth, inShape[1], inShape[1]]; + } + else { + throw new Error("Unknown dataFormat " + dataFormat); + } + return computeConv3DInfo(inShape, filterShape, strides, dilations, pad, false, $dataFormat, roundingMode); + } + /** + * Computes the information for a forward pass of a convolution/pooling + * operation. + */ + function computeConv2DInfo(inShape, filterShape, strides, dilations, pad, roundingMode, depthwise, dataFormat) { + if (depthwise === void 0) { depthwise = false; } + if (dataFormat === void 0) { dataFormat = 'channelsLast'; } + var _a = [-1, -1, -1, -1], batchSize = _a[0], inHeight = _a[1], inWidth = _a[2], inChannels = _a[3]; + if (dataFormat === 'channelsLast') { + batchSize = inShape[0], inHeight = inShape[1], inWidth = inShape[2], inChannels = inShape[3]; + } + else if (dataFormat === 'channelsFirst') { + batchSize = inShape[0], inChannels = inShape[1], inHeight = inShape[2], inWidth = inShape[3]; + } + else { + throw new Error("Unknown dataFormat " + dataFormat); + } + var filterHeight = filterShape[0], filterWidth = filterShape[1], filterChannels = filterShape[3]; + var _b = parseTupleParam(strides), strideHeight = _b[0], strideWidth = _b[1]; + var _c = parseTupleParam(dilations), dilationHeight = _c[0], dilationWidth = _c[1]; + var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); + var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); + var _d = getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outHeight = _d.outHeight, outWidth = _d.outWidth; + var outChannels = depthwise ? filterChannels * inChannels : filterChannels; + var outShape; + if (dataFormat === 'channelsFirst') { + outShape = [batchSize, outChannels, outHeight, outWidth]; + } + else if (dataFormat === 'channelsLast') { + outShape = [batchSize, outHeight, outWidth, outChannels]; + } + return { + batchSize: batchSize, + dataFormat: dataFormat, + inHeight: inHeight, + inWidth: inWidth, + inChannels: inChannels, + outHeight: outHeight, + outWidth: outWidth, + outChannels: outChannels, + padInfo: padInfo, + strideHeight: strideHeight, + strideWidth: strideWidth, + filterHeight: filterHeight, + filterWidth: filterWidth, + effectiveFilterHeight: effectiveFilterHeight, + effectiveFilterWidth: effectiveFilterWidth, + dilationHeight: dilationHeight, + dilationWidth: dilationWidth, + inShape: inShape, + outShape: outShape, + filterShape: filterShape + }; + } + /** + * Computes the information for a forward pass of a 3D convolution/pooling + * operation. + */ + function computeConv3DInfo(inShape, filterShape, strides, dilations, pad, depthwise, dataFormat, roundingMode) { + if (depthwise === void 0) { depthwise = false; } + if (dataFormat === void 0) { dataFormat = 'channelsLast'; } + var _a = [-1, -1, -1, -1, -1], batchSize = _a[0], inDepth = _a[1], inHeight = _a[2], inWidth = _a[3], inChannels = _a[4]; + if (dataFormat === 'channelsLast') { + batchSize = inShape[0], inDepth = inShape[1], inHeight = inShape[2], inWidth = inShape[3], inChannels = inShape[4]; + } + else if (dataFormat === 'channelsFirst') { + batchSize = inShape[0], inChannels = inShape[1], inDepth = inShape[2], inHeight = inShape[3], inWidth = inShape[4]; + } + else { + throw new Error("Unknown dataFormat " + dataFormat); + } + var filterDepth = filterShape[0], filterHeight = filterShape[1], filterWidth = filterShape[2], filterChannels = filterShape[4]; + var _b = parse3TupleParam(strides), strideDepth = _b[0], strideHeight = _b[1], strideWidth = _b[2]; + var _c = parse3TupleParam(dilations), dilationDepth = _c[0], dilationHeight = _c[1], dilationWidth = _c[2]; + var effectiveFilterDepth = getEffectiveFilterSize(filterDepth, dilationDepth); + var effectiveFilterHeight = getEffectiveFilterSize(filterHeight, dilationHeight); + var effectiveFilterWidth = getEffectiveFilterSize(filterWidth, dilationWidth); + var _d = get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, effectiveFilterDepth, effectiveFilterHeight, effectiveFilterWidth, roundingMode), padInfo = _d.padInfo, outDepth = _d.outDepth, outHeight = _d.outHeight, outWidth = _d.outWidth; + var outChannels = depthwise ? filterChannels * inChannels : filterChannels; + var outShape; + if (dataFormat === 'channelsFirst') { + outShape = [batchSize, outChannels, outDepth, outHeight, outWidth]; + } + else if (dataFormat === 'channelsLast') { + outShape = [batchSize, outDepth, outHeight, outWidth, outChannels]; + } + return { + batchSize: batchSize, + dataFormat: dataFormat, + inDepth: inDepth, + inHeight: inHeight, + inWidth: inWidth, + inChannels: inChannels, + outDepth: outDepth, + outHeight: outHeight, + outWidth: outWidth, + outChannels: outChannels, + padInfo: padInfo, + strideDepth: strideDepth, + strideHeight: strideHeight, + strideWidth: strideWidth, + filterDepth: filterDepth, + filterHeight: filterHeight, + filterWidth: filterWidth, + effectiveFilterDepth: effectiveFilterDepth, + effectiveFilterHeight: effectiveFilterHeight, + effectiveFilterWidth: effectiveFilterWidth, + dilationDepth: dilationDepth, + dilationHeight: dilationHeight, + dilationWidth: dilationWidth, + inShape: inShape, + outShape: outShape, + filterShape: filterShape + }; + } + function computeOutputShape2D(inShape, fieldSize, stride, zeroPad, roundingMode) { + if (zeroPad == null) { + zeroPad = computeDefaultPad(inShape, fieldSize, stride); + } + var inputRows = inShape[0]; + var inputCols = inShape[1]; + var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); + assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " + + "Change the stride and/or zero pad parameters"; }); + var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); + assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " + + "Change the stride and/or zero pad parameters"; }); + return [outputRows, outputCols]; + } + function computeOutputShape4D(inShape, fieldSize, outChannels, stride, zeroPad, roundingMode) { + if (zeroPad == null) { + zeroPad = computeDefaultPad(inShape, fieldSize, stride); + } + var inputDepth = inShape[0]; + var inputRows = inShape[1]; + var inputCols = inShape[2]; + var outputDepths = conditionalRound((inputDepth - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); + assert(isInt(outputDepths), function () { return "The output # of depths (" + outputDepths + ") must be an integer. " + + "Change the stride and/or zero pad parameters"; }); + var outputRows = conditionalRound((inputRows - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); + assert(isInt(outputRows), function () { return "The output # of rows (" + outputRows + ") must be an integer. " + + "Change the stride and/or zero pad parameters"; }); + var outputCols = conditionalRound((inputCols - fieldSize + 2 * zeroPad) / stride + 1, roundingMode); + assert(isInt(outputCols), function () { return "The output # of columns (" + outputCols + ") must be an integer. " + + "Change the stride and/or zero pad parameters"; }); + return [outputDepths, outputRows, outputCols, outChannels]; + } + function computeDefaultPad(inputShape, fieldSize, stride, dilation) { + if (dilation === void 0) { dilation = 1; } + var effectiveFieldSize = getEffectiveFilterSize(fieldSize, dilation); + return Math.floor((inputShape[0] * (stride - 1) - stride + effectiveFieldSize) / 2); + } + function parseTupleParam(param) { + if (typeof param === 'number') { + return [param, param, param]; + } + if (param.length === 2) { + return [param[0], param[1], 1]; + } + return param; + } + function parse3TupleParam(param) { + return typeof param === 'number' ? [param, param, param] : param; + } + /* See https://www.tensorflow.org/api_docs/python/tf/nn/atrous_conv2d + * Atrous convolution is equivalent to standard convolution with upsampled + * filters with effective_filter_height = + * filter_height + (filter_height - 1) * (dilation - 1) + * and effective_filter_width = + * filter_width + (filter_width - 1) * (dilation - 1), + * produced by inserting dilation - 1 zeros along consecutive elements across + * the filters' spatial dimensions. + * When there is a dilation, this converts a filter dimension to the + * effective filter dimension, so it can be used in a standard convolution. + */ + function getEffectiveFilterSize(filterSize, dilation) { + if (dilation <= 1) { + return filterSize; + } + return filterSize + (filterSize - 1) * (dilation - 1); + } + function getPadAndOutInfo(pad, inHeight, inWidth, strideHeight, strideWidth, filterHeight, filterWidth, roundingMode) { + var padInfo; + var outHeight; + var outWidth; + if (typeof pad === 'number') { + var padType = (pad === 0) ? 'VALID' : 'NUMBER'; + padInfo = { top: pad, bottom: pad, left: pad, right: pad, type: padType }; + var outShape = computeOutputShape2D([inHeight, inWidth], filterHeight, strideHeight, pad, roundingMode); + outHeight = outShape[0]; + outWidth = outShape[1]; + } + else if (pad === 'same') { + outHeight = Math.ceil(inHeight / strideHeight); + outWidth = Math.ceil(inWidth / strideWidth); + var padAlongHeight = Math.max(0, (outHeight - 1) * strideHeight + filterHeight - inHeight); + var padAlongWidth = Math.max(0, (outWidth - 1) * strideWidth + filterWidth - inWidth); + var top_1 = Math.floor(padAlongHeight / 2); + var bottom = padAlongHeight - top_1; + var left = Math.floor(padAlongWidth / 2); + var right = padAlongWidth - left; + padInfo = { top: top_1, bottom: bottom, left: left, right: right, type: 'SAME' }; + } + else if (pad === 'valid') { + padInfo = { top: 0, bottom: 0, left: 0, right: 0, type: 'VALID' }; + outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); + outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); + } + else { + throw Error("Unknown padding parameter: " + pad); + } + return { padInfo: padInfo, outHeight: outHeight, outWidth: outWidth }; + } + function get3DPadAndOutInfo(pad, inDepth, inHeight, inWidth, strideDepth, strideHeight, strideWidth, filterDepth, filterHeight, filterWidth, roundingMode) { + var padInfo; + var outDepth; + var outHeight; + var outWidth; + if (typeof pad === 'number') { + var padType = (pad === 0) ? 'VALID' : 'NUMBER'; + padInfo = { + top: pad, + bottom: pad, + left: pad, + right: pad, + front: pad, + back: pad, + type: padType + }; + var outShape = computeOutputShape4D([inDepth, inHeight, inWidth, 1], filterDepth, 1, strideDepth, pad, roundingMode); + outDepth = outShape[0]; + outHeight = outShape[1]; + outWidth = outShape[2]; + } + else if (pad === 'same') { + outDepth = Math.ceil(inDepth / strideDepth); + outHeight = Math.ceil(inHeight / strideHeight); + outWidth = Math.ceil(inWidth / strideWidth); + var padAlongDepth = (outDepth - 1) * strideDepth + filterDepth - inDepth; + var padAlongHeight = (outHeight - 1) * strideHeight + filterHeight - inHeight; + var padAlongWidth = (outWidth - 1) * strideWidth + filterWidth - inWidth; + var front = Math.floor(padAlongDepth / 2); + var back = padAlongDepth - front; + var top_2 = Math.floor(padAlongHeight / 2); + var bottom = padAlongHeight - top_2; + var left = Math.floor(padAlongWidth / 2); + var right = padAlongWidth - left; + padInfo = { top: top_2, bottom: bottom, left: left, right: right, front: front, back: back, type: 'SAME' }; + } + else if (pad === 'valid') { + padInfo = { + top: 0, + bottom: 0, + left: 0, + right: 0, + front: 0, + back: 0, + type: 'VALID' + }; + outDepth = Math.ceil((inDepth - filterDepth + 1) / strideDepth); + outHeight = Math.ceil((inHeight - filterHeight + 1) / strideHeight); + outWidth = Math.ceil((inWidth - filterWidth + 1) / strideWidth); + } + else { + throw Error("Unknown padding parameter: " + pad); + } + return { padInfo: padInfo, outDepth: outDepth, outHeight: outHeight, outWidth: outWidth }; + } + /** + * Rounds a value depending on the rounding mode + * @param value + * @param roundingMode + */ + function conditionalRound(value, roundingMode) { + if (!roundingMode) { + return value; + } + switch (roundingMode) { + case 'round': + // used for Caffe Conv + return Math.round(value); + case 'ceil': + // used for Caffe Pool + return Math.ceil(value); + case 'floor': + return Math.floor(value); + default: + throw new Error("Unknown roundingMode " + roundingMode); + } + } + function tupleValuesAreOne(param) { + var _a = parseTupleParam(param), dimA = _a[0], dimB = _a[1], dimC = _a[2]; + return dimA === 1 && dimB === 1 && dimC === 1; + } + function eitherStridesOrDilationsAreOne(strides, dilations) { + return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations); + } + /** + * Convert Conv2D dataFormat from 'NHWC'|'NCHW' to + * 'channelsLast'|'channelsFirst' + * @param dataFormat in 'NHWC'|'NCHW' mode + * @return dataFormat in 'channelsLast'|'channelsFirst' mode + * @throws unknown dataFormat + */ + function convertConv2DDataFormat(dataFormat) { + if (dataFormat === 'NHWC') { + return 'channelsLast'; + } + else if (dataFormat === 'NCHW') { + return 'channelsFirst'; + } + else { + throw new Error("Unknown dataFormat " + dataFormat); + } + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function castTensor(x, dtype, backend) { + if (dtype === 'complex64') { + if (x.dtype === 'complex64') { + return x.clone(); + } + var zerosTensor = zeros(x.shape); + var floatX = x.toFloat(); + var result = backend.complex(floatX, zerosTensor); + zerosTensor.dispose(); + floatX.dispose(); + return result; + } + if (!hasEncodingLoss(x.dtype, dtype)) { + // We don't change the underlying data, since we cast to higher + // precision. + return ENGINE.makeTensorFromDataId(x.dataId, x.shape, dtype); + } + if (x.dtype === 'complex64') { + var real = backend.real(x); + var result = real.cast(dtype); + real.dispose(); + return result; + } + if (dtype === 'int32') { + return backend.int(x); + } + else if (dtype === 'bool') { + var zero = scalar(0, x.dtype); + var result = backend.notEqual(x, zero); + zero.dispose(); + return result; + } + else { + throw new Error("Error in Cast: failed to cast " + x.dtype + " to " + dtype); + } + } + function reshapeTensor(x, shape) { + return ENGINE.makeTensorFromDataId(x.dataId, shape, x.dtype); + } + function linspaceImpl(start, stop, num) { + var step = (stop - start) / (num - 1); + var values = makeZerosTypedArray(num, 'float32'); + values[0] = start; + for (var i = 1; i < values.length; i++) { + values[i] = values[i - 1] + step; + } + return tensor1d(values, 'float32'); + } + + var backend_util = /*#__PURE__*/Object.freeze({ + castTensor: castTensor, + reshapeTensor: reshapeTensor, + linspaceImpl: linspaceImpl, + upcastType: upcastType, + axesAreInnerMostDims: axesAreInnerMostDims, + combineLocations: combineLocations, + computeOutAndReduceShapes: computeOutAndReduceShapes, + expandShapeToKeepDim: expandShapeToKeepDim, + assertAxesAreInnerMostDims: assertAxesAreInnerMostDims, + getAxesPermutation: getAxesPermutation, + getUndoAxesPermutation: getUndoAxesPermutation, + getInnerMostAxes: getInnerMostAxes, + getBroadcastDims: getBroadcastDims, + getReductionAxes: getReductionAxes, + assertAndGetBroadcastShape: assertAndGetBroadcastShape, + assertParamsConsistent: assertParamsConsistent, + computeOutShape: computeOutShape, + computePool2DInfo: computePool2DInfo, + computePool3DInfo: computePool3DInfo, + computeConv2DInfo: computeConv2DInfo, + computeConv3DInfo: computeConv3DInfo, + computeDefaultPad: computeDefaultPad, + tupleValuesAreOne: tupleValuesAreOne, + eitherStridesOrDilationsAreOne: eitherStridesOrDilationsAreOne, + convertConv2DDataFormat: convertConv2DDataFormat + }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Merges real and imaginary Float32Arrays into a single complex Float32Array. + * + * The memory layout is interleaved as follows: + * real: [r0, r1, r2] + * imag: [i0, i1, i2] + * complex: [r0, i0, r1, i1, r2, i2] + * + * This is the inverse of splitRealAndImagArrays. + * + * @param real The real values of the complex tensor values. + * @param imag The imag values of the complex tensor values. + * @returns A complex tensor as a Float32Array with merged values. + */ + function mergeRealAndImagArrays(real, imag) { + if (real.length !== imag.length) { + throw new Error("Cannot merge real and imag arrays of different lengths. real:" + + (real.length + ", imag: " + imag.length + ".")); + } + var result = new Float32Array(real.length * 2); + for (var i = 0; i < result.length; i += 2) { + result[i] = real[i / 2]; + result[i + 1] = imag[i / 2]; + } + return result; + } + /** + * Splits a complex Float32Array into real and imag parts. + * + * The memory layout is interleaved as follows: + * complex: [r0, i0, r1, i1, r2, i2] + * real: [r0, r1, r2] + * imag: [i0, i1, i2] + * + * This is the inverse of mergeRealAndImagArrays. + * + * @param complex The complex tensor values. + * @returns An object with real and imag Float32Array components of the complex + * tensor. + */ + function splitRealAndImagArrays(complex) { + var real = new Float32Array(complex.length / 2); + var imag = new Float32Array(complex.length / 2); + for (var i = 0; i < complex.length; i += 2) { + real[i / 2] = complex[i]; + imag[i / 2] = complex[i + 1]; + } + return { real: real, imag: imag }; + } + /** + * Extracts even indexed complex values in the given array. + * @param complex The complex tensor values + */ + function complexWithEvenIndex(complex) { + var len = Math.ceil(complex.length / 4); + var real = new Float32Array(len); + var imag = new Float32Array(len); + for (var i = 0; i < complex.length; i += 4) { + real[Math.floor(i / 4)] = complex[i]; + imag[Math.floor(i / 4)] = complex[i + 1]; + } + return { real: real, imag: imag }; + } + /** + * Extracts odd indexed comple values in the given array. + * @param complex The complex tensor values + */ + function complexWithOddIndex(complex) { + var len = Math.floor(complex.length / 4); + var real = new Float32Array(len); + var imag = new Float32Array(len); + for (var i = 2; i < complex.length; i += 4) { + real[Math.floor(i / 4)] = complex[i]; + imag[Math.floor(i / 4)] = complex[i + 1]; + } + return { real: real, imag: imag }; + } + /** + * Get the map representing a complex value in the given array. + * @param complex The complex tensor values. + * @param index An index of the target complex value. + */ + function getComplexWithIndex(complex, index) { + var real = complex[index * 2]; + var imag = complex[index * 2 + 1]; + return { real: real, imag: imag }; + } + /** + * Insert a given complex value into the TypedArray. + * @param data The array in which the complex value is inserted. + * @param c The complex value to be inserted. + * @param index An index of the target complex value. + */ + function assignToTypedArray(data, real, imag, index) { + data[index * 2] = real; + data[index * 2 + 1] = imag; + } + /** + * Make the list of exponent terms used by FFT. + */ + function exponents(n, inverse) { + var real = new Float32Array(n / 2); + var imag = new Float32Array(n / 2); + for (var i = 0; i < Math.ceil(n / 2); i++) { + var x = (inverse ? 2 : -2) * Math.PI * (i / n); + real[i] = Math.cos(x); + imag[i] = Math.sin(x); + } + return { real: real, imag: imag }; + } + /** + * Make the exponent term used by FFT. + */ + function exponent(k, n, inverse) { + var x = (inverse ? 2 : -2) * Math.PI * (k / n); + var real = Math.cos(x); + var imag = Math.sin(x); + return { real: real, imag: imag }; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function nonMaxSuppressionImpl(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + var candidates = Array.from(scores) + .map(function (score, boxIndex) { return ({ score: score, boxIndex: boxIndex }); }) + .filter(function (c) { return c.score > scoreThreshold; }) + .sort(function (c1, c2) { return c2.score - c1.score; }); + var selected = []; + for (var i = 0; i < candidates.length; i++) { + var _a = candidates[i], score = _a.score, boxIndex = _a.boxIndex; + if (score < scoreThreshold) { + break; + } + var ignoreCandidate = false; + for (var j = selected.length - 1; j >= 0; --j) { + var iou = intersectionOverUnion(boxes, boxIndex, selected[j]); + if (iou >= iouThreshold) { + ignoreCandidate = true; + break; + } + } + if (!ignoreCandidate) { + selected.push(boxIndex); + if (selected.length >= maxOutputSize) { + break; + } + } + } + return tensor1d(selected, 'int32'); + } + function intersectionOverUnion(boxes, i, j) { + var iCoord = boxes.subarray(i * 4, i * 4 + 4); + var jCoord = boxes.subarray(j * 4, j * 4 + 4); + var yminI = Math.min(iCoord[0], iCoord[2]); + var xminI = Math.min(iCoord[1], iCoord[3]); + var ymaxI = Math.max(iCoord[0], iCoord[2]); + var xmaxI = Math.max(iCoord[1], iCoord[3]); + var yminJ = Math.min(jCoord[0], jCoord[2]); + var xminJ = Math.min(jCoord[1], jCoord[3]); + var ymaxJ = Math.max(jCoord[0], jCoord[2]); + var xmaxJ = Math.max(jCoord[1], jCoord[3]); + var areaI = (ymaxI - yminI) * (xmaxI - xminI); + var areaJ = (ymaxJ - yminJ) * (xmaxJ - xminJ); + if (areaI <= 0 || areaJ <= 0) { + return 0.0; + } + var intersectionYmin = Math.max(yminI, yminJ); + var intersectionXmin = Math.max(xminI, xminJ); + var intersectionYmax = Math.min(ymaxI, ymaxJ); + var intersectionXmax = Math.min(xmaxI, xmaxJ); + var intersectionArea = Math.max(intersectionYmax - intersectionYmin, 0.0) * + Math.max(intersectionXmax - intersectionXmin, 0.0); + return intersectionArea / (areaI + areaJ - intersectionArea); + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** Shared implementation of the split kernel across WebGL and CPU. */ + function split$1(x, sizeSplits, axis) { + var begin = new Array(x.rank).fill(0); + var size = x.shape.slice(); + return sizeSplits.map(function (s) { + size[axis] = s; + var slice = x.slice(begin, size); + begin[axis] += s; + return slice; + }); + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function tile$1(xBuf, reps) { + var newShape = new Array(xBuf.rank); + for (var i = 0; i < newShape.length; i++) { + newShape[i] = xBuf.shape[i] * reps[i]; + } + var result = buffer(newShape, xBuf.dtype); + for (var i = 0; i < result.values.length; ++i) { + var newLoc = result.indexToLoc(i); + var originalLoc = new Array(xBuf.rank); + for (var i_1 = 0; i_1 < originalLoc.length; i_1++) { + originalLoc[i_1] = newLoc[i_1] % xBuf.shape[i_1]; + } + var originalIndex = xBuf.locToIndex(originalLoc); + result.values[i] = xBuf.values[originalIndex]; + } + return result.toTensor(); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function topkImpl(x, xShape, xDtype, k, sorted) { + // Reshape into a 2d tensor [batch, lastDim] and compute topk along lastDim. + var lastDim = xShape[xShape.length - 1]; + var _a = [x.length / lastDim, lastDim], batch = _a[0], size = _a[1]; + var allTopKVals = getTypedArrayFromDType(xDtype, batch * k); + var allTopKIndices = getTypedArrayFromDType('int32', batch * k); + for (var b = 0; b < batch; b++) { + var offset = b * size; + var vals = x.subarray(offset, offset + size); + var valAndInd = []; + for (var i = 0; i < vals.length; i++) { + valAndInd.push({ value: vals[i], index: i }); + } + valAndInd.sort(function (a, b) { return b.value - a.value; }); + var outOffset = b * k; + var topKVals = allTopKVals.subarray(outOffset, outOffset + k); + var topKIndices = allTopKIndices.subarray(outOffset, outOffset + k); + for (var i = 0; i < k; i++) { + topKVals[i] = valAndInd[i].value; + topKIndices[i] = valAndInd[i].index; + } + } + // Reshape back to the original input shape, except that the last + // dimension is k. + var outputShape = xShape.slice(); + outputShape[outputShape.length - 1] = k; + return [ + tensor(allTopKVals, outputShape, xDtype), + tensor(allTopKIndices, outputShape, 'int32') + ]; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function whereImpl(condShape, condVals) { + var indices = []; + for (var i = 0; i < condVals.length; i++) { + if (condVals[i]) { + indices.push(i); + } + } + var inBuffer = buffer(condShape, 'int32'); + var out = buffer([indices.length, condShape.length], 'int32'); + for (var i = 0; i < indices.length; i++) { + var loc = inBuffer.indexToLoc(indices[i]); + var offset = i * condShape.length; + out.values.set(loc, offset); + } + return out.toTensor(); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var AddNProgram = /** @class */ (function () { + function AddNProgram(outputShape, shapes) { + this.outputShape = []; + this.outputShape = outputShape; + this.variableNames = shapes.map(function (_, i) { return "T" + i; }); + var snippets = []; + // Get target elements from every input tensor. + this.variableNames.forEach(function (variable) { + snippets.push("float v" + variable + " = get" + variable + "AtOutCoords();"); + }); + // Calculate the sum of all elements. + var operation = this.variableNames + .map(function (variable) { + return "v" + variable; + }) + .join(' + '); + this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n float result = " + operation + ";\n setOutput(result);\n }\n "; + } + return AddNProgram; + }()); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var AddNPackedProgram = /** @class */ (function () { + function AddNPackedProgram(outputShape, shapes) { + this.outputShape = []; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = outputShape; + this.variableNames = shapes.map(function (_, i) { return "T" + i; }); + var snippets = []; + // Get target elements from every input tensor. + this.variableNames.forEach(function (variable) { + snippets.push("vec4 v" + variable + " = get" + variable + "AtOutCoords();"); + }); + // Calculate the sum of all elements. + var operation = this.variableNames + .map(function (variable) { + return "v" + variable; + }) + .join(' + '); + this.userCode = "\n void main() {\n " + snippets.join('\n ') + "\n\n vec4 result = " + operation + ";\n setOutput(result);\n }\n "; + } + return AddNPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ArgMinMaxProgram = /** @class */ (function () { + function ArgMinMaxProgram(reduceInfo, op, firstPass) { + this.variableNames = ['A']; + var windowSize = reduceInfo.windowSize; + var batchSize = reduceInfo.batchSize; + var inSize = reduceInfo.inSize; + var outSize = Math.ceil(inSize / windowSize); + if (!firstPass) { + this.variableNames.push('bestIndicesA'); + } + this.outputShape = [batchSize, outSize]; + var compOp = (op === 'max') ? '>' : '<'; + var indexSnippet = firstPass ? + 'inOffset + i;' : + 'round(getBestIndicesA(batch, inOffset + i));'; + this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n int bestIndex = inOffset;\n float bestValue = getA(batch, bestIndex);\n\n for (int i = 0; i < " + windowSize + "; i++) {\n int inIdx = " + indexSnippet + ";\n float candidate = getA(batch, inIdx);\n if (candidate " + compOp + " bestValue) {\n bestValue = candidate;\n bestIndex = inIdx;\n }\n }\n setOutput(float(bestIndex));\n }\n "; + } + return ArgMinMaxProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function getVecChannels(name, rank) { + return ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank).map(function (d) { return name + "." + d; }); + } + function getChannels(name, rank) { + if (rank === 1) { + return [name]; + } + return getVecChannels(name, rank); + } + function getSourceCoords(rank, dims) { + if (rank === 1) { + return 'rc'; + } + var coords = ''; + for (var i = 0; i < rank; i++) { + coords += dims[i]; + if (i < rank - 1) { + coords += ','; + } + } + return coords; + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function getGlslDifferences() { + var version; + var attribute; + var varyingVs; + var varyingFs; + var texture2D; + var output; + var defineOutput; + var defineSpecialNaN; + var defineSpecialInf; + var defineRound; + if (env().getNumber('WEBGL_VERSION') === 2) { + version = '#version 300 es'; + attribute = 'in'; + varyingVs = 'out'; + varyingFs = 'in'; + texture2D = 'texture'; + output = 'outputColor'; + defineOutput = 'out vec4 outputColor;'; + // Use custom isnan definition to work across differences between + // implementations on various platforms. While this should happen in ANGLE + // we still see differences between android and windows (on chrome) when + // using isnan directly. + defineSpecialNaN = "\n bool isnan_custom(float val) {\n return (val > 0.0 || val < 0.0) ? false : val != 0.0;\n }\n\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan_custom(val.x),\n isnan_custom(val.y), isnan_custom(val.z), isnan_custom(val.w));\n }\n\n #define isnan(value) isnan_custom(value)\n "; + // In webgl 2 we do not need to specify a custom isinf so there is no + // need for a special INFINITY constant. + defineSpecialInf = ""; + defineRound = "\n #define round(value) newRound(value)\n int newRound(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 newRound(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "; + } + else { + version = ''; + attribute = 'attribute'; + varyingVs = 'varying'; + varyingFs = 'varying'; + texture2D = 'texture2D'; + output = 'gl_FragColor'; + defineOutput = ''; + // WebGL1 has no built in isnan so we define one here. + defineSpecialNaN = "\n #define isnan(value) isnan_custom(value)\n bool isnan_custom(float val) {\n return (val > 0. || val < 1. || val == 0.) ? false : true;\n }\n bvec4 isnan_custom(vec4 val) {\n return bvec4(isnan(val.x), isnan(val.y), isnan(val.z), isnan(val.w));\n }\n "; + defineSpecialInf = "\n uniform float INFINITY;\n\n bool isinf(float val) {\n return abs(val) == INFINITY;\n }\n bvec4 isinf(vec4 val) {\n return equal(abs(val), vec4(INFINITY));\n }\n "; + defineRound = "\n int round(float value) {\n return int(floor(value + 0.5));\n }\n\n ivec4 round(vec4 value) {\n return ivec4(floor(value + vec4(0.5)));\n }\n "; + } + return { + version: version, + attribute: attribute, + varyingVs: varyingVs, + varyingFs: varyingFs, + texture2D: texture2D, + output: output, + defineOutput: defineOutput, + defineSpecialNaN: defineSpecialNaN, + defineSpecialInf: defineSpecialInf, + defineRound: defineRound + }; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Produces GLSL code that derives logical coordinates from a flat + * index. The code performs integer division with each stride and decrements + * the index until the index equals the final dimension coordinate. + */ + function getLogicalCoordinatesFromFlatIndex(coords, shape, index) { + if (index === void 0) { index = 'index'; } + var strides = computeStrides(shape); + return strides + .map(function (stride, i) { + var line1 = "int " + coords[i] + " = " + index + " / " + stride; + var line2 = i === strides.length - 1 ? + "int " + coords[i + 1] + " = " + index + " - " + coords[i] + " * " + stride : + "index -= " + coords[i] + " * " + stride; + return line1 + "; " + line2 + ";"; + }) + .join(''); + } + /** + * Produces GLSL that computes the flat index from 3D coordinates. + */ + function getFlatIndexFrom3D(shape) { + var strides = computeStrides(shape).map(function (d) { return d.toString(); }); + return "\n int getFlatIndex(ivec3 coords) {\n return coords.x * " + strides[0] + " + coords.y * " + strides[1] + " + coords.z;\n }\n"; + } + var ENCODE_FLOAT_SNIPPET = "\n const float FLOAT_MAX = 1.70141184e38;\n const float FLOAT_MIN = 1.17549435e-38;\n\n lowp vec4 encode_float(highp float v) {\n if (isnan(v)) {\n return vec4(255, 255, 255, 255);\n }\n\n highp float av = abs(v);\n\n if(av < FLOAT_MIN) {\n return vec4(0.0, 0.0, 0.0, 0.0);\n } else if(v > FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 127.0) / 255.0;\n } else if(v < -FLOAT_MAX) {\n return vec4(0.0, 0.0, 128.0, 255.0) / 255.0;\n }\n\n highp vec4 c = vec4(0,0,0,0);\n\n highp float e = floor(log2(av));\n highp float m = exp2(fract(log2(av))) - 1.0;\n\n c[2] = floor(128.0 * m);\n m -= c[2] / 128.0;\n c[1] = floor(32768.0 * m);\n m -= c[1] / 32768.0;\n c[0] = floor(8388608.0 * m);\n\n highp float ebias = e + 127.0;\n c[3] = floor(ebias / 2.0);\n ebias -= c[3] * 2.0;\n c[2] += floor(ebias) * 128.0;\n\n c[3] += 128.0 * step(0.0, -v);\n\n return c / 255.0;\n }\n"; + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function makeShader(inputsInfo, outputShape, userCode, usesPackedTextures) { + var prefixSnippets = []; + inputsInfo.forEach(function (x) { + var size = sizeFromShape(x.shapeInfo.logicalShape); + // Snippet when we decided to upload the values as uniform. + if (x.shapeInfo.isUniform) { + prefixSnippets.push("uniform float " + x.name + (size > 1 ? "[" + size + "]" : '') + ";"); + } + else { + prefixSnippets.push("uniform sampler2D " + x.name + ";"); + prefixSnippets.push("uniform int offset" + x.name + ";"); + } + }); + var inputPrefixSnippet = prefixSnippets.join('\n'); + var inputSamplingSnippet = inputsInfo + .map(function (x) { return getInputSamplingSnippet(x, outputShape, usesPackedTextures); }) + .join('\n'); + var outTexShape = outputShape.texShape; + var glsl = getGlslDifferences(); + var floatTextureSampleSnippet = getFloatTextureSampleSnippet(glsl); + var outputSamplingSnippet; + var floatTextureSetOutputSnippet; + var shaderPrefix = getShaderPrefix(glsl); + if (outputShape.isPacked) { + outputSamplingSnippet = + getPackedOutputSamplingSnippet(outputShape.logicalShape, outTexShape); + floatTextureSetOutputSnippet = getFloatTextureSetRGBASnippet(glsl); + } + else { + outputSamplingSnippet = + getOutputSamplingSnippet(outputShape.logicalShape, outTexShape); + floatTextureSetOutputSnippet = getFloatTextureSetRSnippet(glsl); + } + if (usesPackedTextures) { + shaderPrefix += SHADER_PACKED_PREFIX; + } + var source = [ + shaderPrefix, floatTextureSampleSnippet, floatTextureSetOutputSnippet, + inputPrefixSnippet, outputSamplingSnippet, inputSamplingSnippet, userCode + ].join('\n'); + return source; + } + function getSamplerFromInInfo(inInfo) { + var shape = inInfo.shapeInfo.logicalShape; + switch (shape.length) { + case 0: + return getSamplerScalar(inInfo); + case 1: + return getSampler1D(inInfo); + case 2: + return getSampler2D(inInfo); + case 3: + return getSampler3D(inInfo); + case 4: + return getSampler4D(inInfo); + case 5: + return getSampler5D(inInfo); + case 6: + return getSampler6D(inInfo); + default: + throw new Error(shape.length + "-D input sampling" + + " is not yet supported"); + } + } + function getPackedSamplerFromInInfo(inInfo) { + var shape = inInfo.shapeInfo.logicalShape; + switch (shape.length) { + case 0: + return getPackedSamplerScalar(inInfo); + case 1: + return getPackedSampler1D(inInfo); + case 2: + return getPackedSampler2D(inInfo); + case 3: + return getPackedSampler3D(inInfo); + default: + return getPackedSamplerND(inInfo); + } + } + function getInputSamplingSnippet(inInfo, outShapeInfo, usesPackedTextures) { + if (usesPackedTextures === void 0) { usesPackedTextures = false; } + var res = ''; + if (usesPackedTextures) { + res += getPackedSamplerFromInInfo(inInfo); + } + else { + res += getSamplerFromInInfo(inInfo); + } + var inShape = inInfo.shapeInfo.logicalShape; + var outShape = outShapeInfo.logicalShape; + if (inShape.length <= outShape.length) { + if (usesPackedTextures) { + res += getPackedSamplerAtOutputCoords(inInfo, outShapeInfo); + } + else { + res += getSamplerAtOutputCoords(inInfo, outShapeInfo); + } + } + return res; + } + function getPackedOutputSamplingSnippet(outShape, outTexShape) { + switch (outShape.length) { + case 0: + return getOutputScalarCoords(); + case 1: + return getOutputPacked1DCoords(outShape, outTexShape); + case 2: + return getOutputPacked2DCoords(outShape, outTexShape); + case 3: + return getOutputPacked3DCoords(outShape, outTexShape); + default: + return getOutputPackedNDCoords(outShape, outTexShape); + } + } + function getOutputSamplingSnippet(outShape, outTexShape) { + switch (outShape.length) { + case 0: + return getOutputScalarCoords(); + case 1: + return getOutput1DCoords(outShape, outTexShape); + case 2: + return getOutput2DCoords(outShape, outTexShape); + case 3: + return getOutput3DCoords(outShape, outTexShape); + case 4: + return getOutput4DCoords(outShape, outTexShape); + case 5: + return getOutput5DCoords(outShape, outTexShape); + case 6: + return getOutput6DCoords(outShape, outTexShape); + default: + throw new Error(outShape.length + "-D output sampling is not yet supported"); + } + } + function getFloatTextureSampleSnippet(glsl) { + return "\n float sampleTexture(sampler2D textureSampler, vec2 uv) {\n return " + glsl.texture2D + "(textureSampler, uv).r;\n }\n "; + } + function getFloatTextureSetRSnippet(glsl) { + return "\n void setOutput(float val) {\n " + glsl.output + " = vec4(val, 0, 0, 0);\n }\n "; + } + function getFloatTextureSetRGBASnippet(glsl) { + return "\n void setOutput(vec4 val) {\n " + glsl.output + " = val;\n }\n "; + } + function getShaderPrefix(glsl) { + var SHADER_PREFIX = glsl.version + "\n precision highp float;\n precision highp int;\n precision highp sampler2D;\n " + glsl.varyingFs + " vec2 resultUV;\n " + glsl.defineOutput + "\n const vec2 halfCR = vec2(0.5, 0.5);\n\n struct ivec5\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n };\n\n struct ivec6\n {\n int x;\n int y;\n int z;\n int w;\n int u;\n int v;\n };\n\n uniform float NAN;\n " + glsl.defineSpecialNaN + "\n " + glsl.defineSpecialInf + "\n " + glsl.defineRound + "\n\n int imod(int x, int y) {\n return x - y * (x / y);\n }\n\n int idiv(int a, int b, float sign) {\n int res = a / b;\n int mod = imod(a, b);\n if (sign < 0. && mod != 0) {\n res -= 1;\n }\n return res;\n }\n\n //Based on the work of Dave Hoskins\n //https://www.shadertoy.com/view/4djSRW\n #define HASHSCALE1 443.8975\n float random(float seed){\n vec2 p = resultUV * seed;\n vec3 p3 = fract(vec3(p.xyx) * HASHSCALE1);\n p3 += dot(p3, p3.yzx + 19.19);\n return fract((p3.x + p3.y) * p3.z);\n }\n\n " + SAMPLE_1D_SNIPPET + "\n " + SAMPLE_2D_SNIPPET + "\n " + SAMPLE_3D_SNIPPET + "\n "; + return SHADER_PREFIX; + } + var SAMPLE_1D_SNIPPET = "\nvec2 uvFromFlat(int texNumR, int texNumC, int index) {\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\nvec2 packedUVfrom1D(int texNumR, int texNumC, int index) {\n int texelIndex = index / 2;\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; + var SAMPLE_2D_SNIPPET = "\nvec2 packedUVfrom2D(int texelsInLogicalRow, int texNumR,\n int texNumC, int row, int col) {\n int texelIndex = (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = texelIndex / texNumC;\n int texC = texelIndex - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; + var SAMPLE_3D_SNIPPET = "\nvec2 packedUVfrom3D(int texNumR, int texNumC,\n int texelsInBatch, int texelsInLogicalRow, int b,\n int row, int col) {\n int index = b * texelsInBatch + (row / 2) * texelsInLogicalRow + (col / 2);\n int texR = index / texNumC;\n int texC = index - texR * texNumC;\n return (vec2(texC, texR) + halfCR) / vec2(texNumC, texNumR);\n}\n"; + var SHADER_PACKED_PREFIX = "\n float getChannel(vec4 frag, vec2 innerDims) {\n vec2 modCoord = mod(innerDims, 2.);\n return modCoord.x == 0. ?\n (modCoord.y == 0. ? frag.r : frag.g) :\n (modCoord.y == 0. ? frag.b : frag.a);\n }\n float getChannel(vec4 frag, int dim) {\n float modCoord = mod(float(dim), 2.);\n return modCoord == 0. ? frag.r : frag.g;\n }\n"; + function getOutputScalarCoords() { + return "\n int getOutputCoords() {\n return 0;\n }\n "; + } + function getOutputPacked1DCoords(shape, texShape) { + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + if (packedTexShape[0] === 1) { + return "\n int getOutputCoords() {\n return 2 * int(resultUV.x * " + packedTexShape[1] + ".0);\n }\n "; + } + if (packedTexShape[1] === 1) { + return "\n int getOutputCoords() {\n return 2 * int(resultUV.y * " + packedTexShape[0] + ".0);\n }\n "; + } + return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n return 2 * (resTexRC.x * " + packedTexShape[1] + " + resTexRC.y);\n }\n "; + } + function getOutput1DCoords(shape, texShape) { + if (texShape[0] === 1) { + return "\n int getOutputCoords() {\n return int(resultUV.x * " + texShape[1] + ".0);\n }\n "; + } + if (texShape[1] === 1) { + return "\n int getOutputCoords() {\n return int(resultUV.y * " + texShape[0] + ".0);\n }\n "; + } + return "\n int getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n return resTexRC.x * " + texShape[1] + " + resTexRC.y;\n }\n "; + } + function getOutputPacked3DCoords(shape, texShape) { + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + var texelsInLogicalRow = Math.ceil(shape[2] / 2); + var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[1] / 2); + return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec3(b, r, c);\n }\n "; + } + function getOutput3DCoords(shape, texShape) { + var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); + return "\n ivec3 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n "; + } + function getOutputPackedNDCoords(shape, texShape) { + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + var texelsInLogicalRow = Math.ceil(shape[shape.length - 1] / 2); + var texelsInBatch = texelsInLogicalRow * Math.ceil(shape[shape.length - 2] / 2); + var texelsInBatchN = texelsInBatch; + var batches = ""; + var coords = 'b, r, c'; + for (var b = 2; b < shape.length - 1; b++) { + texelsInBatchN *= shape[shape.length - b - 1]; + batches = "\n int b" + b + " = index / " + texelsInBatchN + ";\n index -= b" + b + " * " + texelsInBatchN + ";\n " + batches; + coords = "b" + b + ", " + coords; + } + return "\n ivec" + shape.length + " getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n\n " + batches + "\n\n int b = index / " + texelsInBatch + ";\n index -= b * " + texelsInBatch + ";\n\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec" + shape.length + "(" + coords + ");\n }\n "; + } + function getOutput4DCoords(shape, texShape) { + var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2'], shape); + return "\n ivec4 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n " + coordsFromIndexSnippet + "\n return ivec4(r, c, d, d2);\n }\n "; + } + function getOutput5DCoords(shape, texShape) { + var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3'], shape); + return "\n ivec5 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx * vec2(" + texShape[0] + ",\n " + texShape[1] + "));\n\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec5 outShape = ivec5(r, c, d, d2, d3);\n return outShape;\n }\n "; + } + function getOutput6DCoords(shape, texShape) { + var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd', 'd2', 'd3', 'd4'], shape); + return "\n ivec6 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n\n " + coordsFromIndexSnippet + "\n\n ivec6 result = ivec6(r, c, d, d2, d3, d4);\n return result;\n }\n "; + } + function getOutputPacked2DCoords(shape, texShape) { + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + if (arraysEqual(shape, texShape)) { + return "\n ivec2 getOutputCoords() {\n return 2 * ivec2(resultUV.yx * vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n }\n "; + } + // texels needed to accommodate a logical row + var texelsInLogicalRow = Math.ceil(shape[1] / 2); + /** + * getOutputCoords + * + * resTexRC: The rows and columns of the texels. If you move over one + * texel to the right in the packed texture, you are moving over one column + * (not two). + * + * index: The texel index + */ + return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + packedTexShape[0] + ", " + packedTexShape[1] + "));\n\n int index = resTexRC.x * " + packedTexShape[1] + " + resTexRC.y;\n int r = 2 * (index / " + texelsInLogicalRow + ");\n int c = imod(index, " + texelsInLogicalRow + ") * 2;\n\n return ivec2(r, c);\n }\n "; + } + function getOutput2DCoords(shape, texShape) { + if (arraysEqual(shape, texShape)) { + return "\n ivec2 getOutputCoords() {\n return ivec2(resultUV.yx * vec2(" + texShape[0] + ", " + texShape[1] + "));\n }\n "; + } + if (shape[1] === 1) { + return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(index, 0);\n }\n "; + } + if (shape[0] === 1) { + return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n return ivec2(0, index);\n }\n "; + } + return "\n ivec2 getOutputCoords() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = resTexRC.x * " + texShape[1] + " + resTexRC.y;\n int r = index / " + shape[1] + ";\n int c = index - r * " + shape[1] + ";\n return ivec2(r, c);\n }\n "; + } + function getFlatOffsetUniformName(texName) { + return "offset" + texName; + } + function getPackedSamplerScalar(inputInfo) { + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var glsl = getGlslDifferences(); + return "\n vec4 " + funcName + "() {\n return " + glsl.texture2D + "(" + texName + ", halfCR);\n }\n "; + } + function getSamplerScalar(inputInfo) { + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + if (inputInfo.shapeInfo.isUniform) { + return "float " + funcName + "() {return " + texName + ";}"; + } + var _a = inputInfo.shapeInfo.texShape, texNumR = _a[0], texNumC = _a[1]; + if (texNumR === 1 && texNumC === 1) { + return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", halfCR);\n }\n "; + } + var _b = inputInfo.shapeInfo.texShape, tNumR = _b[0], tNumC = _b[1]; + var offset = getFlatOffsetUniformName(texName); + return "\n float " + funcName + "() {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + function getPackedSampler1D(inputInfo) { + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var texShape = inputInfo.shapeInfo.texShape; + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + var glsl = getGlslDifferences(); + return "\n vec4 " + funcName + "(int index) {\n vec2 uv = packedUVfrom1D(\n " + packedTexShape[0] + ", " + packedTexShape[1] + ", index);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; + } + function getSampler1D(inputInfo) { + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + if (inputInfo.shapeInfo.isUniform) { + // Uniform arrays will be less than 65505 (no risk of float16 overflow). + return "\n float " + funcName + "(int index) {\n " + getUniformSampler(inputInfo) + "\n }\n "; + } + var texShape = inputInfo.shapeInfo.texShape; + var tNumR = texShape[0]; + var tNumC = texShape[1]; + if (tNumC === 1 && tNumR === 1) { + return "\n float " + funcName + "(int index) {\n return sampleTexture(" + texName + ", halfCR);\n }\n "; + } + var offset = getFlatOffsetUniformName(texName); + if (tNumC === 1) { + return "\n float " + funcName + "(int index) {\n vec2 uv = vec2(0.5, (float(index + " + offset + ") + 0.5) / " + tNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + if (tNumR === 1) { + return "\n float " + funcName + "(int index) {\n vec2 uv = vec2((float(index + " + offset + ") + 0.5) / " + tNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + return "\n float " + funcName + "(int index) {\n vec2 uv = uvFromFlat(" + tNumR + ", " + tNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + function getPackedSampler2D(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var texShape = inputInfo.shapeInfo.texShape; + var texNumR = texShape[0]; + var texNumC = texShape[1]; + var glsl = getGlslDifferences(); + if (texShape != null && arraysEqual(shape, texShape)) { + return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; + } + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + var valuesPerRow = Math.ceil(shape[1] / 2); + return "\n vec4 " + funcName + "(int row, int col) {\n vec2 uv = packedUVfrom2D(" + valuesPerRow + ", " + packedTexShape[0] + ", " + packedTexShape[1] + ", row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; + } + function getSampler2D(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var texShape = inputInfo.shapeInfo.texShape; + if (texShape != null && arraysEqual(shape, texShape)) { + var texNumR_1 = texShape[0]; + var texNumC_1 = texShape[1]; + return "\n float " + funcName + "(int row, int col) {\n vec2 uv = (vec2(col, row) + halfCR) / vec2(" + texNumC_1 + ".0, " + texNumR_1 + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; + var squeezedShape = newShape; + if (squeezedShape.length < shape.length) { + var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); + var params = ['row', 'col']; + return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; + } + if (inputInfo.shapeInfo.isUniform) { + // Uniform arrays will be less than 65505 (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col) {\n int index = round(dot(vec2(row, col), vec2(" + shape[1] + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; + } + var texNumR = texShape[0]; + var texNumC = texShape[1]; + var offset = getFlatOffsetUniformName(texName); + if (texNumC === 1) { + // index is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2(0.5, (index + 0.5) / " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + if (texNumR === 1) { + // index is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col) {\n float index = dot(vec3(row, col, " + offset + "), vec3(" + shape[1] + ", 1, 1));\n vec2 uv = vec2((index + 0.5) / " + texNumC + ".0, 0.5);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + return "\n float " + funcName + "(int row, int col) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + shape[1] + " + col + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n"; + } + function getPackedSampler3D(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var texShape = inputInfo.shapeInfo.texShape; + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + if (shape[0] === 1) { + var squeezedShape = shape.slice(1); + var keptDims = [1, 2]; + var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); + var params = ['b', 'row', 'col']; + return "\n " + getPackedSamplerFromInInfo(newInputInfo) + "\n vec4 " + funcName + "(int b, int row, int col) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; + } + var texNumR = packedTexShape[0]; + var texNumC = packedTexShape[1]; + var valuesPerRow = Math.ceil(shape[2] / 2); + var texelsInBatch = valuesPerRow * Math.ceil(shape[1] / 2); + var glsl = getGlslDifferences(); + return "\n vec4 " + funcName + "(int b, int row, int col) {\n vec2 uv = packedUVfrom3D(\n " + texNumR + ", " + texNumC + ", " + texelsInBatch + ", " + valuesPerRow + ", b, row, col);\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; + } + function getSampler3D(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var stride0 = shape[1] * shape[2]; + var stride1 = shape[2]; + var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; + var squeezedShape = newShape; + if (squeezedShape.length < shape.length) { + var newInputInfo = squeezeInputInfo(inputInfo, squeezedShape); + var params = ['row', 'col', 'depth']; + return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; + } + if (inputInfo.shapeInfo.isUniform) { + // Uniform arrays will be less than 65505 (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth) {\n int index = round(dot(vec3(row, col, depth),\n vec3(" + stride0 + ", " + stride1 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; + } + var texShape = inputInfo.shapeInfo.texShape; + var texNumR = texShape[0]; + var texNumC = texShape[1]; + var flatOffset = inputInfo.shapeInfo.flatOffset; + if (texNumC === stride0 && flatOffset == null) { + // texC is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = float(row);\n float texC = dot(vec2(col, depth), vec2(" + stride1 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + if (texNumC === stride1 && flatOffset == null) { + // texR is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth) {\n float texR = dot(vec2(row, col), vec2(" + shape[1] + ", 1));\n float texC = float(depth);\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + var offset = getFlatOffsetUniformName(texName); + return "\n float " + funcName + "(int row, int col, int depth) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + function getPackedSamplerND(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var rank = shape.length; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var texShape = inputInfo.shapeInfo.texShape; + var packedTexShape = [Math.ceil(texShape[0] / 2), Math.ceil(texShape[1] / 2)]; + var texNumR = packedTexShape[0]; + var texNumC = packedTexShape[1]; + var valuesPerRow = Math.ceil(shape[rank - 1] / 2); + var texelsInBatch = valuesPerRow * Math.ceil(shape[rank - 2] / 2); + var params = "int b, int row, int col"; + var index = "b * " + texelsInBatch + " + (row / 2) * " + valuesPerRow + " + (col / 2)"; + for (var b = 2; b < rank - 1; b++) { + params = "int b" + b + ", " + params; + texelsInBatch *= shape[rank - b - 1]; + index = "b" + b + " * " + texelsInBatch + " + " + index; + } + var glsl = getGlslDifferences(); + return "\n vec4 " + funcName + "(" + params + ") {\n int index = " + index + ";\n int texR = index / " + texNumC + ";\n int texC = index - texR * " + texNumC + ";\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + texNumC + ", " + texNumR + ");\n return " + glsl.texture2D + "(" + texName + ", uv);\n }\n "; + } + function getSampler4D(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var stride2 = shape[3]; + var stride1 = shape[2] * stride2; + var stride0 = shape[1] * stride1; + var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; + if (newShape.length < shape.length) { + var newInputInfo = squeezeInputInfo(inputInfo, newShape); + var params = ['row', 'col', 'depth', 'depth2']; + return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; + } + if (inputInfo.shapeInfo.isUniform) { + // Uniform arrays will be less than 65505 (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n int index = round(dot(vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; + } + var flatOffset = inputInfo.shapeInfo.flatOffset; + var texShape = inputInfo.shapeInfo.texShape; + var texNumR = texShape[0]; + var texNumC = texShape[1]; + if (texNumC === stride0 && flatOffset == null) { + // texC is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = float(row);\n float texC =\n dot(vec3(col, depth, depth2),\n vec3(" + stride1 + ", " + stride2 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + if (texNumC === stride2 && flatOffset == null) { + // texR is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n float texR = dot(vec3(row, col, depth),\n vec3(" + shape[1] * shape[2] + ", " + shape[2] + ", 1));\n float texC = float(depth2);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + var offset = getFlatOffsetUniformName(texName); + return "\n float " + funcName + "(int row, int col, int depth, int depth2) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " +\n depth * " + stride2 + " + depth2;\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index + " + offset + ");\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + function getSampler5D(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var stride3 = shape[4]; + var stride2 = shape[3] * stride3; + var stride1 = shape[2] * stride2; + var stride0 = shape[1] * stride1; + var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; + if (newShape.length < shape.length) { + var newInputInfo = squeezeInputInfo(inputInfo, newShape); + var params = ['row', 'col', 'depth', 'depth2', 'depth3']; + return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; + } + if (inputInfo.shapeInfo.isUniform) { + // Uniform arrays will be less than 65505 (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float index = dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n depth3;\n " + getUniformSampler(inputInfo) + "\n }\n "; + } + var flatOffset = inputInfo.shapeInfo.flatOffset; + var texShape = inputInfo.shapeInfo.texShape; + var texNumR = texShape[0]; + var texNumC = texShape[1]; + if (texNumC === stride0 && flatOffset == null) { + // texC is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", 1));\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + if (texNumC === stride3 && flatOffset == null) { + // texR is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n float texR = dot(\n vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] + ",\n " + shape[2] * shape[3] + ", " + shape[3] + ", 1));\n int texC = depth3;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + var offset = getFlatOffsetUniformName(texName); + return "\n float " + funcName + "(int row, int col, int depth, int depth2, int depth3) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + function getSampler6D(inputInfo) { + var shape = inputInfo.shapeInfo.logicalShape; + var texName = inputInfo.name; + var funcName = 'get' + texName.charAt(0).toUpperCase() + texName.slice(1); + var _a = squeezeShape(shape), newShape = _a.newShape, keptDims = _a.keptDims; + if (newShape.length < shape.length) { + var newInputInfo = squeezeInputInfo(inputInfo, newShape); + var params = ['row', 'col', 'depth', 'depth2', 'depth3', 'depth4']; + return "\n " + getSamplerFromInInfo(newInputInfo) + "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n return " + funcName + "(" + getSqueezedParams(params, keptDims) + ");\n }\n "; + } + var stride4 = shape[5]; + var stride3 = shape[4] * stride4; + var stride2 = shape[3] * stride3; + var stride1 = shape[2] * stride2; + var stride0 = shape[1] * stride1; + if (inputInfo.shapeInfo.isUniform) { + // Uniform arrays will be less than 65505 (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int index = round(dot(\n vec4(row, col, depth, depth2),\n vec4(" + stride0 + ", " + stride1 + ", " + stride2 + ", " + stride3 + ")) +\n dot(\n vec2(depth3, depth4),\n vec2(" + stride4 + ", 1)));\n " + getUniformSampler(inputInfo) + "\n }\n "; + } + var flatOffset = inputInfo.shapeInfo.flatOffset; + var texShape = inputInfo.shapeInfo.texShape; + var texNumR = texShape[0]; + var texNumC = texShape[1]; + if (texNumC === stride0 && flatOffset == null) { + // texC is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n int texR = row;\n float texC = dot(vec4(col, depth, depth2, depth3),\n vec4(" + stride1 + ", " + stride2 + ", " + stride3 + ", " + stride4 + ")) +\n float(depth4);\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + if (texNumC === stride4 && flatOffset == null) { + // texR is used directly as physical (no risk of float16 overflow). + return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n float texR = dot(vec4(row, col, depth, depth2),\n vec4(" + shape[1] * shape[2] * shape[3] * shape[4] + ",\n " + shape[2] * shape[3] * shape[4] + ",\n " + shape[3] * shape[4] + ",\n " + shape[4] + ")) + float(depth3);\n int texC = depth4;\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + texNumC + ".0, " + texNumR + ".0);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + var offset = getFlatOffsetUniformName(texName); + return "\n float " + funcName + "(int row, int col, int depth,\n int depth2, int depth3, int depth4) {\n // Explicitly use integer operations as dot() only works on floats.\n int index = row * " + stride0 + " + col * " + stride1 + " + depth * " + stride2 + " +\n depth2 * " + stride3 + " + depth3 * " + stride4 + " + depth4 + " + offset + ";\n vec2 uv = uvFromFlat(" + texNumR + ", " + texNumC + ", index);\n return sampleTexture(" + texName + ", uv);\n }\n "; + } + function getUniformSampler(inputInfo) { + var texName = inputInfo.name; + var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape); + if (inSize < 2) { + return "return " + texName + ";"; + } + return "\n for (int i = 0; i < " + inSize + "; i++) {\n if (i == index) {\n return " + texName + "[i];\n }\n }\n "; + } + function getPackedSamplerAtOutputCoords(inputInfo, outShapeInfo) { + var texName = inputInfo.name; + var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); + var funcName = 'get' + texFuncSnippet + 'AtOutCoords'; + var inRank = inputInfo.shapeInfo.logicalShape.length; + var outRank = outShapeInfo.logicalShape.length; + var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); + var type = getCoordsDataType(outRank); + var rankDiff = outRank - inRank; + var coordsSnippet; + var fields = ['x', 'y', 'z', 'w', 'u', 'v']; + if (inRank === 0) { + coordsSnippet = ''; + } + else if (outRank < 2 && broadcastDims.length >= 1) { + coordsSnippet = 'coords = 0;'; + } + else { + coordsSnippet = + broadcastDims.map(function (d) { return "coords." + fields[d + rankDiff] + " = 0;"; }) + .join('\n'); + } + var unpackedCoordsSnippet = ''; + if (outRank < 2 && inRank > 0) { + unpackedCoordsSnippet = 'coords'; + } + else { + unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape + .map(function (s, i) { return "coords." + fields[i + rankDiff]; }) + .join(', '); + } + var output = "return outputValue;"; + var inSize = sizeFromShape(inputInfo.shapeInfo.logicalShape); + var isInputScalar = inSize === 1; + var outSize = sizeFromShape(outShapeInfo.logicalShape); + var isOutputScalar = outSize === 1; + if (inRank === 1 && !isInputScalar && !isOutputScalar) { + output = "\n return vec4(outputValue.xy, outputValue.xy);\n "; + } + else if (isInputScalar && !isOutputScalar) { + if (outRank === 1) { + output = "\n return vec4(outputValue.x, outputValue.x, 0., 0.);\n "; + } + else { + output = "\n return vec4(outputValue.x);\n "; + } + } + else if (broadcastDims.length) { + var rows = inRank - 2; + var cols = inRank - 1; + if (broadcastDims.indexOf(rows) > -1 && broadcastDims.indexOf(cols) > -1) { + output = "return vec4(outputValue.x);"; + } + else if (broadcastDims.indexOf(rows) > -1) { + output = "return vec4(outputValue.x, outputValue.y, " + + "outputValue.x, outputValue.y);"; + } + else if (broadcastDims.indexOf(cols) > -1) { + output = "return vec4(outputValue.xx, outputValue.zz);"; + } + } + return "\n vec4 " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n vec4 outputValue = get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n " + output + "\n }\n "; + } + function getSamplerAtOutputCoords(inputInfo, outShapeInfo) { + var texName = inputInfo.name; + var texFuncSnippet = texName.charAt(0).toUpperCase() + texName.slice(1); + var funcName = 'get' + texFuncSnippet + 'AtOutCoords'; + var outTexShape = outShapeInfo.texShape; + var inTexShape = inputInfo.shapeInfo.texShape; + var inRank = inputInfo.shapeInfo.logicalShape.length; + var outRank = outShapeInfo.logicalShape.length; + if (!inputInfo.shapeInfo.isUniform && inRank === outRank && + inputInfo.shapeInfo.flatOffset == null && + arraysEqual(inTexShape, outTexShape)) { + return "\n float " + funcName + "() {\n return sampleTexture(" + texName + ", resultUV);\n }\n "; + } + var type = getCoordsDataType(outRank); + var broadcastDims = getBroadcastDims(inputInfo.shapeInfo.logicalShape, outShapeInfo.logicalShape); + var rankDiff = outRank - inRank; + var coordsSnippet; + var fields = ['x', 'y', 'z', 'w', 'u', 'v']; + if (inRank === 0) { + coordsSnippet = ''; + } + else if (outRank < 2 && broadcastDims.length >= 1) { + coordsSnippet = 'coords = 0;'; + } + else { + coordsSnippet = + broadcastDims.map(function (d) { return "coords." + fields[d + rankDiff] + " = 0;"; }) + .join('\n'); + } + var unpackedCoordsSnippet = ''; + if (outRank < 2 && inRank > 0) { + unpackedCoordsSnippet = 'coords'; + } + else { + unpackedCoordsSnippet = inputInfo.shapeInfo.logicalShape + .map(function (s, i) { return "coords." + fields[i + rankDiff]; }) + .join(', '); + } + return "\n float " + funcName + "() {\n " + type + " coords = getOutputCoords();\n " + coordsSnippet + "\n return get" + texFuncSnippet + "(" + unpackedCoordsSnippet + ");\n }\n "; + } + function getCoordsDataType(rank) { + if (rank <= 1) { + return 'int'; + } + else if (rank === 2) { + return 'ivec2'; + } + else if (rank === 3) { + return 'ivec3'; + } + else if (rank === 4) { + return 'ivec4'; + } + else if (rank === 5) { + return 'ivec5'; + } + else if (rank === 6) { + return 'ivec6'; + } + else { + throw Error("GPU for rank " + rank + " is not yet supported"); + } + } + /** Returns a new input info (a copy) that has a squeezed logical shape. */ + function squeezeInputInfo(inInfo, squeezedShape) { + // Deep copy. + var newInputInfo = JSON.parse(JSON.stringify(inInfo)); + newInputInfo.shapeInfo.logicalShape = squeezedShape; + return newInputInfo; + } + function getSqueezedParams(params, keptDims) { + return keptDims.map(function (d) { return params[d]; }).join(', '); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ArgMinMaxPackedProgram = /** @class */ (function () { + function ArgMinMaxPackedProgram(shape, windowSize, op, firstPass) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + assert(shape.length > 2, function () { return "Packed arg" + (op.charAt(0).toUpperCase() + + op.slice(1)) + " supports only inputs with rank above 2."; }); + var inSize = shape[shape.length - 1]; + var outSize = Math.ceil(inSize / windowSize); + this.outputShape = shape.slice(0, -1); + if (outSize > 1) { + this.outputShape.push(outSize); + } + if (!firstPass) { + this.variableNames.push('bestIndicesA'); + } + var outShape = this.outputShape; + var rank = outShape.length; + var dtype = getCoordsDataType(rank); + var coords = getChannels('coords', rank); + var sourceLocSetup; + var sourceRank; + if (outSize === 1) { + sourceRank = rank + 1; + var sourceLocDType = getCoordsDataType(sourceRank); + sourceLocSetup = "\n " + sourceLocDType + " sourceLocR = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocG = " + sourceLocDType + "(" + coords.join() + ", 0);\n ++" + coords[rank - 2] + ";\n " + sourceLocDType + " sourceLocA = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 1] + ";\n " + sourceLocDType + " sourceLocB = " + sourceLocDType + "(" + coords.join() + ", 0);\n --" + coords[rank - 2] + ";"; + } + else { + sourceRank = rank; + sourceLocSetup = "\n " + dtype + " sourceLocR = coords;\n ++" + coords[rank - 1] + ";\n " + dtype + " sourceLocG = coords;\n ++" + coords[rank - 2] + ";\n " + dtype + " sourceLocA = coords;\n --" + coords[rank - 1] + ";\n " + dtype + " sourceLocB = coords;\n --" + coords[rank - 2] + ";"; + } + var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, sourceRank); + var inChannel = '.' + channels[sourceRank - 1]; // e.g. ".b" for rank 3. + var intChannels = channels.map(function (x) { return 'int ' + x; }); + var srcRCoords = getChannels('sourceLocR', sourceRank - 1).concat('inIdx.r'); + var srcGCoords = getChannels('sourceLocG', sourceRank - 1).concat('inIdx.g'); + var srcBCoords = getChannels('sourceLocB', sourceRank - 1).concat('inIdx.b'); + var srcACoords = getChannels('sourceLocA', sourceRank - 1).concat('inIdx.a'); + var compOp = (op === 'max') ? 'greaterThan' : 'lessThan'; + var fetchCandidateIdx = firstPass ? '' : "\n inIdx = round(vec4(getBestIndicesAChannel(" + srcRCoords.join() + "),\n getBestIndicesAChannel(" + srcGCoords.join() + "),\n getBestIndicesAChannel(" + srcBCoords.join() + "),\n getBestIndicesAChannel(" + srcACoords.join() + ")));"; + var fetchValue = "vec4(\n getAChannel(" + srcRCoords.join() + "),\n hasNextCol ? getAChannel(" + srcGCoords.join() + ") : 0.,\n hasNextRow ? getAChannel(" + srcBCoords.join() + ") : 0.,\n hasNextRow && hasNextCol ? getAChannel(" + srcACoords.join() + ") : 0.)"; + var getBestIndicesAChannelSnippet = firstPass ? '' : "\n float getBestIndicesAChannel(" + intChannels.join() + ") {\n return getChannel(getBestIndicesA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }"; + this.userCode = "\n float getAChannel(" + intChannels.join() + ") {\n return getChannel(getA(" + channels.join() + "),\n vec2(" + channels.slice(-2).join() + "));\n }\n " + getBestIndicesAChannelSnippet + "\n void main() {\n " + dtype + " coords = getOutputCoords();\n bool hasNextCol = " + coords[rank - 1] + " < " + (outShape[rank - 1] - 1) + ";\n bool hasNextRow = " + coords[rank - 2] + " < " + (outShape[rank - 2] - 1) + ";\n " + sourceLocSetup + "\n ivec4 srcIdx = ivec4(sourceLocR" + inChannel + ", sourceLocG" + inChannel + ",\n sourceLocB" + inChannel + ", sourceLocA" + inChannel + ") * " + windowSize + ";\n ivec4 inIdx = srcIdx;\n vec4 bestIndex = vec4(inIdx);\n vec4 bestValue = " + fetchValue + ";\n\n for (int i = 0; i < " + windowSize + "; i++) {\n inIdx = srcIdx;\n " + fetchCandidateIdx + "\n vec4 candidate = " + fetchValue + ";\n bvec4 nan = isnan(candidate);\n bvec4 replace = bvec4(\n vec4(" + compOp + "(candidate, bestValue)) * (vec4(1.0) - vec4(nan)));\n\n bestValue = vec4(replace.x ? candidate.x : bestValue.x,\n replace.y ? candidate.y : bestValue.y,\n replace.z ? candidate.z : bestValue.z,\n replace.w ? candidate.w : bestValue.w);\n bestIndex = mix(bestIndex, vec4(inIdx), vec4(replace));\n srcIdx++;\n }\n setOutput(bestIndex);\n }\n "; + } + return ArgMinMaxPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var AvgPool2DBackpropProgram = /** @class */ (function () { + function AvgPool2DBackpropProgram(convInfo) { + this.variableNames = ['dy']; + this.outputShape = convInfo.inShape; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var avgMultiplier = 1 / (filterHeight * filterWidth); + this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC+= " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n setOutput(dotProd);\n }\n "; + } + return AvgPool2DBackpropProgram; + }()); + var AvgPool3DBackpropProgram = /** @class */ (function () { + function AvgPool3DBackpropProgram(convInfo) { + this.variableNames = ['dy']; + this.outputShape = convInfo.inShape; + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth); + this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float avgMultiplier = float(" + avgMultiplier + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, d) with pos mask(:, :, :, ch) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n\n dotProd += dyValue * avgMultiplier;\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return AvgPool3DBackpropProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var BatchNormProgram = /** @class */ (function () { + function BatchNormProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) { + this.outputShape = []; + this.variableNames = ['x', 'mean', 'variance']; + assertAndGetBroadcastShape(xShape, meanShape); + assertAndGetBroadcastShape(xShape, varianceShape); + var offsetSnippet = '0.0'; + if (offsetShape != null) { + assertAndGetBroadcastShape(xShape, offsetShape); + this.variableNames.push('offset'); + offsetSnippet = 'getOffsetAtOutCoords()'; + } + var scaleSnippet = '1.0'; + if (scaleShape != null) { + assertAndGetBroadcastShape(xShape, scaleShape); + this.variableNames.push('scale'); + scaleSnippet = 'getScaleAtOutCoords()'; + } + this.outputShape = xShape; + this.userCode = "\n void main() {\n float x = getXAtOutCoords();\n float mean = getMeanAtOutCoords();\n float variance = getVarianceAtOutCoords();\n float offset = " + offsetSnippet + ";\n float scale = " + scaleSnippet + ";\n float inv = scale * inversesqrt(variance + float(" + varianceEpsilon + "));\n setOutput(dot(vec3(x, -mean, offset), vec3(inv, inv, 1)));\n }\n "; + } + return BatchNormProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var BatchNormPackedProgram = /** @class */ (function () { + function BatchNormPackedProgram(xShape, meanShape, varianceShape, offsetShape, scaleShape, varianceEpsilon) { + this.packedInputs = true; + this.packedOutput = true; + this.variableNames = ['x', 'mean', 'variance']; + assertAndGetBroadcastShape(xShape, meanShape); + assertAndGetBroadcastShape(xShape, varianceShape); + var offsetSnippet = 'vec4(0.0)'; + if (offsetShape != null) { + assertAndGetBroadcastShape(xShape, offsetShape); + this.variableNames.push('offset'); + offsetSnippet = 'getOffsetAtOutCoords()'; + } + var scaleSnippet = 'vec4(1.0)'; + if (scaleShape != null) { + assertAndGetBroadcastShape(xShape, scaleShape); + this.variableNames.push('scale'); + scaleSnippet = 'getScaleAtOutCoords()'; + } + this.outputShape = xShape; + this.userCode = "\n void main() {\n vec4 offset = " + offsetSnippet + ";\n vec4 scale = " + scaleSnippet + ";\n\n vec4 x = getXAtOutCoords();\n vec4 mean = getMeanAtOutCoords();\n vec4 variance = getVarianceAtOutCoords();\n\n vec4 inv = scale * inversesqrt(variance + vec4(" + varianceEpsilon + "));\n\n setOutput((x - mean) * inv + offset);\n }\n "; + } + return BatchNormPackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // (Ar + Ai)(Br + Bi) = + // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr + // Yr = ArBr - AB + // Yi = ArBi + AiBr + var COMPLEX_MULTIPLY = { + REAL: 'return areal * breal - aimag * bimag;', + IMAG: 'return areal * bimag + aimag * breal;' + }; + var BinaryOpComplexProgram = /** @class */ (function () { + function BinaryOpComplexProgram(op, aShape, bShape) { + this.variableNames = ['AReal', 'AImag', 'BReal', 'BImag']; + this.outputShape = + assertAndGetBroadcastShape(aShape, bShape); + this.userCode = "\n float binaryOpComplex(\n float areal, float aimag, float breal, float bimag) {\n " + op + "\n }\n\n void main() {\n float areal = getARealAtOutCoords();\n float aimag = getAImagAtOutCoords();\n float breal = getBRealAtOutCoords();\n float bimag = getBImagAtOutCoords();\n setOutput(binaryOpComplex(areal, aimag, breal, bimag));\n }\n "; + } + return BinaryOpComplexProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var CHECK_NAN_SNIPPET = "\n if (isnan(a)) return a;\n if (isnan(b)) return b;\n"; + var ADD = 'return a + b;'; + var SUB = 'return a - b;'; + var MUL = 'return a * b;'; + // Without the equality check div produces 0.9999 for a = b, which when + // floored can cause errors. + var DIV = "\nif (a == b) {\n return 1.0;\n};\nreturn a / b;"; + // We use native integer division to deal with floating point imprecision. Since + // we implement floor division and glsl implements truncated division, we + // correct for this by subtracting 1 from result when the result is negative and + // there is a remainder. + var INT_DIV = "\n float s = sign(a) * sign(b);\n int ia = round(a);\n int ib = round(b);\n if (ib != 0) {\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n return float(idiv(ia, ib, s));\n } else {\n return NAN;\n }\n"; + var POW = "\nif(a < 0.0 && floor(b) < b){\n return NAN;\n}\nif (b == 0.0) {\n return 1.0;\n}\nreturn (round(mod(b, 2.0)) != 1) ?\n pow(abs(a), b) : sign(a) * pow(abs(a), b);\n"; + var SQUARED_DIFFERENCE = 'return (a - b) * (a - b);'; + var EQUAL = "return float(a == b);"; + var NOT_EQUAL = "return float(a != b);"; + var LESS = "return float(a < b);"; + var LESS_EQUAL = "return float(a <= b);"; + var GREATER = "return float(a > b);"; + var GREATER_EQUAL = "return float(a >= b);"; + var LOGICAL_AND = "return float(a >= 1.0 && b >= 1.0);"; + var LOGICAL_OR = "return float(a >= 1.0 || b >= 1.0);"; + var MAX = CHECK_NAN_SNIPPET + "\n return max(a, b);\n"; + var MIN = CHECK_NAN_SNIPPET + "\n return min(a, b);\n"; + var MOD = "if (b == 0.0) return NAN;\n return mod(a, b);"; + var ATAN2 = CHECK_NAN_SNIPPET + "\n return atan(a, b);\n"; + var ELU_DER = "return (b >= 1.0) ? a : a * (b + 1.0);"; + var PRELU = "return (a < 0.) ? b * a : a;"; + var BinaryOpProgram = /** @class */ (function () { + function BinaryOpProgram(op, aShape, bShape) { + this.variableNames = ['A', 'B']; + this.outputShape = + assertAndGetBroadcastShape(aShape, bShape); + this.userCode = "\n float binaryOperation(float a, float b) {\n " + op + "\n }\n\n void main() {\n float a = getAAtOutCoords();\n float b = getBAtOutCoords();\n setOutput(binaryOperation(a, b));\n }\n "; + } + return BinaryOpProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var CHECK_NAN_SNIPPET$1 = "\n result.r = isNaN.r > 0. ? NAN : result.r;\n result.g = isNaN.g > 0. ? NAN : result.g;\n result.b = isNaN.b > 0. ? NAN : result.b;\n result.a = isNaN.a > 0. ? NAN : result.a;\n"; + // We do the same as in ./binaryop_gpu, with vec4 and ivec4. + // On Linux, the vectorized implementation produces NaNs when a and b are 0. + var DIV$1 = "\n // vec4 one = vec4(equal(a, b));\n // return one + (vec4(1.0) - one) * a / b;\n vec4 result = a / b;\n if(a.x == b.x) {\n result.x = 1.;\n }\n if(a.y == b.y) {\n result.y = 1.;\n }\n if(a.z == b.z) {\n result.z = 1.;\n }\n if(a.w == b.w) {\n result.w = 1.;\n }\n\n return result;\n"; + var INT_DIV$1 = "\n ivec4 ia = round(a);\n ivec4 ib = round(b);\n bvec4 cond = notEqual(ib, ivec4(0));\n ivec4 result = ivec4(0);\n vec4 s = sign(a) * sign(b);\n\n // Windows (D3D) wants guaranteed non-zero int division at compile-time.\n if (cond[0]) {\n result[0] = idiv(ia[0], ib[0], s[0]);\n }\n if (cond[1]) {\n result[1] = idiv(ia[1], ib[1], s[1]);\n }\n if (cond[2]) {\n result[2] = idiv(ia[2], ib[2], s[2]);\n }\n if (cond[3]) {\n result[3] = idiv(ia[3], ib[3], s[3]);\n }\n return vec4(result);\n"; + var POW$1 = "\n // isModRound1 has 1 for components with round(mod(b, 2.0)) == 1, 0 otherwise.\n vec4 isModRound1 = vec4(equal(round(mod(b, 2.0)), ivec4(1)));\n vec4 multiplier = sign(a) * isModRound1 + (vec4(1.0) - isModRound1);\n vec4 result = multiplier * pow(abs(a), b);\n\n // Ensure that a^0 = 1, including 0^0 = 1 as this correspond to TF and JS\n bvec4 isExpZero = equal(b, vec4(0.0));\n result.r = isExpZero.r ? 1.0 : result.r;\n result.g = isExpZero.g ? 1.0 : result.g;\n result.b = isExpZero.b ? 1.0 : result.b;\n result.a = isExpZero.a ? 1.0 : result.a;\n\n vec4 isNaN = vec4(lessThan(a, vec4(0.0))) * vec4(lessThan(floor(b), b));\n " + + CHECK_NAN_SNIPPET$1 + "\n return result;\n"; + var PRELU$1 = "\n vec4 aLessThanZero = vec4(lessThan(a, vec4(0.)));\n return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a);\n"; + var ELU_DER$1 = "\n vec4 bGTEZero = vec4(greaterThanEqual(b, vec4(0.)));\n return (bGTEZero * a) + ((vec4(1.0) - bGTEZero) * (a * (b + vec4(1.0))));\n"; + var ATAN2$1 = "\n vec4 result = atan(a, b);\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + + CHECK_NAN_SNIPPET$1 + "\n return result;\n"; + var EQUAL$1 = "\n return vec4(equal(a, b));\n"; + var NOT_EQUAL$1 = "\n return vec4(notEqual(a, b));\n"; + var LESS$1 = "\n return vec4(lessThan(a, b));\n"; + var LESS_EQUAL$1 = "\n return vec4(lessThanEqual(a, b));\n"; + var GREATER$1 = "\n return vec4(greaterThan(a, b));\n"; + var GREATER_EQUAL$1 = "\n return vec4(greaterThanEqual(a, b));\n"; + var LOGICAL_AND$1 = "\n return vec4(\n vec4(greaterThanEqual(a, vec4(1.0))) *\n vec4(greaterThanEqual(b, vec4(1.0))));\n"; + var LOGICAL_OR$1 = "\n return min(\n vec4(greaterThanEqual(a, vec4(1.0))) +\n vec4(greaterThanEqual(b, vec4(1.0))),\n vec4(1.0));\n"; + var MAX$1 = "\n vec4 result = vec4(max(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + + CHECK_NAN_SNIPPET$1 + "\n return result;\n"; + var MIN$1 = "\n vec4 result = vec4(min(a, b));\n vec4 isNaN = min(vec4(isnan(a)) + vec4(isnan(b)), vec4(1.0));\n " + + CHECK_NAN_SNIPPET$1 + "\n return result;\n"; + var MOD$1 = "\n vec4 result = mod(a, b);\n vec4 isNaN = vec4(equal(b, vec4(0.0)));\n " + + CHECK_NAN_SNIPPET$1 + "\n return result;\n"; + var BinaryOpPackedProgram = /** @class */ (function () { + function BinaryOpPackedProgram(op, aShape, bShape, checkOutOfBounds) { + if (checkOutOfBounds === void 0) { checkOutOfBounds = false; } + this.variableNames = ['A', 'B']; + this.supportsBroadcasting = true; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = + assertAndGetBroadcastShape(aShape, bShape); + var rank = this.outputShape.length; + var checkOutOfBoundsString = ''; + if (checkOutOfBounds) { + if (rank === 0 || sizeFromShape(this.outputShape) === 1) { + checkOutOfBoundsString = "\n result.y = 0.;\n result.z = 0.;\n result.w = 0.;\n "; + } + else { + var dtype = getCoordsDataType(rank); + checkOutOfBoundsString = "\n " + dtype + " coords = getOutputCoords();\n "; + if (rank === 1) { + checkOutOfBoundsString += "\n result.y = (coords + 1) >= " + this.outputShape[0] + " ? 0. : result.y;\n result.z = 0.;\n result.w = 0.;\n "; + } + else { + var channels = getChannels('coords', rank); + checkOutOfBoundsString += "\n bool nextRowOutOfBounds =\n (" + channels[rank - 2] + " + 1) >= " + this.outputShape[rank - 2] + ";\n bool nextColOutOfBounds =\n (" + channels[rank - 1] + " + 1) >= " + this.outputShape[rank - 1] + ";\n result.y = nextColOutOfBounds ? 0. : result.y;\n result.z = nextRowOutOfBounds ? 0. : result.z;\n result.w = nextColOutOfBounds || nextRowOutOfBounds ? 0. : result.w;\n "; + } + } + } + this.userCode = "\n vec4 binaryOperation(vec4 a, vec4 b) {\n " + op + "\n }\n\n void main() {\n vec4 a = getAAtOutCoords();\n vec4 b = getBAtOutCoords();\n\n vec4 result = binaryOperation(a, b);\n " + checkOutOfBoundsString + "\n\n setOutput(result);\n }\n "; + } + return BinaryOpPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ClipProgram = /** @class */ (function () { + function ClipProgram(aShape) { + this.variableNames = ['A']; + this.outputShape = aShape; + this.userCode = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n float value = getAAtOutCoords();\n if (isnan(value)) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, minVal, maxVal));\n }\n "; + } + ClipProgram.prototype.getCustomSetupFunc = function (min, max) { + var _this = this; + return function (gpgpu, webGLProgram) { + if (_this.minLoc == null) { + _this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal'); + _this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal'); + } + gpgpu.gl.uniform1f(_this.minLoc, min); + gpgpu.gl.uniform1f(_this.maxLoc, max); + }; + }; + return ClipProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ClipPackedProgram = /** @class */ (function () { + function ClipPackedProgram(aShape) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = aShape; + this.userCode = "\n uniform float minVal;\n uniform float maxVal;\n\n void main() {\n vec4 value = getAAtOutCoords();\n\n if (any(isnan(value))) {\n setOutput(value);\n return;\n }\n\n setOutput(clamp(value, vec4(minVal), vec4(maxVal)));\n }\n "; + } + ClipPackedProgram.prototype.getCustomSetupFunc = function (min, max) { + var _this = this; + return function (gpgpu, webGLProgram) { + if (_this.minLoc == null) { + _this.minLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'minVal'); + _this.maxLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'maxVal'); + } + gpgpu.gl.uniform1f(_this.minLoc, min); + gpgpu.gl.uniform1f(_this.maxLoc, max); + }; + }; + return ClipPackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ComplexAbsProgram = /** @class */ (function () { + function ComplexAbsProgram(shape) { + this.variableNames = ['real', 'imag']; + this.outputShape = shape; + this.userCode = "\n void main() {\n float re = abs(getRealAtOutCoords());\n float im = abs(getImagAtOutCoords());\n float mx = max(re, im);\n\n // sadly the length function in glsl is not underflow-safe\n // (at least not on Intel GPUs). So the safe solution is\n // to ensure underflow-safety in all cases.\n setOutput(\n mx == 0.0 ? 0.0 : mx * length(vec2(1, min(re, im)/mx))\n );\n }\n "; + } + return ComplexAbsProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ConcatProgram = /** @class */ (function () { + // Concats 2d tensors along axis=1. See comments in MathBackendWebGL.concat(). + function ConcatProgram(shapes) { + this.outputShape = []; + this.outputShape = computeOutShape(shapes, 1 /* axis */); + this.variableNames = shapes.map(function (_, i) { return "T" + i; }); + var offsets = new Array(shapes.length - 1); + offsets[0] = shapes[0][1]; + for (var i = 1; i < offsets.length; i++) { + offsets[i] = offsets[i - 1] + shapes[i][1]; + } + var snippets = ["if (yC < " + offsets[0] + ") setOutput(getT0(yR, yC));"]; + for (var i = 1; i < offsets.length; i++) { + var shift = offsets[i - 1]; + snippets.push("else if (yC < " + offsets[i] + ") " + + ("setOutput(getT" + i + "(yR, yC-" + shift + "));")); + } + var lastIndex = offsets.length; + var lastShift = offsets[offsets.length - 1]; + snippets.push("else setOutput(getT" + lastIndex + "(yR, yC-" + lastShift + "));"); + this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int yR = coords.x;\n int yC = coords.y;\n\n " + snippets.join('\n ') + "\n }\n "; + } + return ConcatProgram; + }()); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ConcatPackedProgram = /** @class */ (function () { + function ConcatPackedProgram(shapes, axis) { + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = []; + this.outputShape = computeOutShape(shapes, axis); + var shape = this.outputShape; + var rank = shape.length; + var dtype = getCoordsDataType(rank); + var coords = getChannels('coords', rank); + var channels = ['x', 'y', 'z', 'w', 'u', 'v'].slice(0, rank); + this.variableNames = shapes.map(function (_, i) { return "T" + i; }); + var offsets = new Array(shapes.length - 1); + offsets[0] = shapes[0][axis]; + for (var i = 1; i < offsets.length; i++) { + offsets[i] = offsets[i - 1] + shapes[i][axis]; + } + var channel = channels[axis]; + var lastChannels = channels.slice(-2); + var allChannels = channels.join(); + var getValueSnippet = "if (" + channel + " < " + offsets[0] + ") {\n return getChannel(\n getT0(" + allChannels + "), vec2(" + lastChannels.join() + "));\n }"; + for (var i = 1; i < offsets.length; i++) { + var shift_1 = offsets[i - 1]; + // Note: the >= comparison below may seem unnecessary given the check + // above but is needed to workaround branch execution issues on some + // devices. It makes all the conditions exclusive without relying on + // execution order. + getValueSnippet += "\n if (" + channel + " < " + offsets[i] + " && " + channel + " >= " + offsets[i - 1] + ") {\n return getChannel(\n getT" + i + "(" + shiftedChannels(channels, channel, shift_1) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift_1) + "));\n }"; + } + var lastIndex = offsets.length; + var shift = offsets[offsets.length - 1]; + getValueSnippet += "\n return getChannel(\n getT" + lastIndex + "(" + shiftedChannels(channels, channel, shift) + "),\n vec2(" + shiftedChannels(lastChannels, channel, shift) + "));"; + this.userCode = "\n float getValue(" + channels.map(function (x) { return 'int ' + x; }) + ") {\n " + getValueSnippet + "\n }\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n vec4 result = vec4(getValue(" + coords + "), 0., 0., 0.);\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " + 1;\n if (" + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.g = getValue(" + coords + ");\n }\n\n " + coords[rank - 2] + " = " + coords[rank - 2] + " + 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + ") {\n result.a = getValue(" + coords + ");\n }\n\n " + coords[rank - 1] + " = " + coords[rank - 1] + " - 1;\n if (" + coords[rank - 2] + " < " + shape[rank - 2] + " &&\n " + coords[rank - 1] + " < " + shape[rank - 1] + ") {\n result.b = getValue(" + coords + ");\n }\n setOutput(result);\n }\n "; + } + return ConcatPackedProgram; + }()); + /** + * Return an expression for coordinates into a vector where a given channel + * will be offset by [shift]. + * + * @param channels the channels to consider + * @param channel the channel we want shifted + * @param shift the amount to subtract from the channel. + * + * @returns a string of the form 'x, y-[shift], z' where any one channel can + * have the shift applied. + */ + function shiftedChannels(channels, channel, shift) { + var channelIdx = channels.indexOf(channel); + var res = channels.map(function (c, idx) { + if (idx === channelIdx) { + return c + " - " + shift; + } + else { + return c; + } + }); + return res.join(); + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var Conv2DDerFilterProgram = /** @class */ (function () { + function Conv2DDerFilterProgram(convInfo) { + this.variableNames = ['x', 'dy']; + this.outputShape = convInfo.filterShape; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int d2 = coords.w;\n\n // Convolve x(?, ?, d1) with dy(:, :, d2) to get dw(wR, wC, d1, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n if (" + isChannelsLast + ") {\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n } else {\n float dyValue = getDy(b, d2, yR, yC);\n float xValue = getX(b, d1, xR, xC);\n dotProd += (xValue * dyValue);\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return Conv2DDerFilterProgram; + }()); + var Conv2DDerInputProgram = /** @class */ (function () { + function Conv2DDerInputProgram(convInfo) { + this.variableNames = ['dy', 'W']; + this.outputShape = convInfo.inShape; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + var padTop = filterHeight - 1 - convInfo.padInfo.top; + var padLeft = filterWidth - 1 - convInfo.padInfo.left; + var rowDim = isChannelsLast ? 1 : 2; + var colDim = isChannelsLast ? 2 : 3; + var channelDim = isChannelsLast ? 3 : 1; + this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[" + channelDim + "];\n\n ivec2 dyCorner = ivec2(coords[" + rowDim + "], coords[" + colDim + "]) - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n\n if (" + isChannelsLast + ") {\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n } else {\n float xValue = getDy(batch, d2, idyR, idyC);\n float wValue = getW(wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return Conv2DDerInputProgram; + }()); + var Conv3DDerFilterProgram = /** @class */ (function () { + function Conv3DDerFilterProgram(convInfo) { + this.variableNames = ['x', 'dy']; + this.outputShape = convInfo.filterShape; + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var padFront = convInfo.padInfo.front; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + this.userCode = "\n void main() {\n ivec5 coords = getOutputCoords();\n int wF = coords.x;\n int wR = coords.y;\n int wC = coords.z;\n int d1 = coords.w;\n int d2 = coords.u;\n\n float dotProd = 0.0;\n\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yF = 0; yF < " + convInfo.outDepth + "; yF++) {\n int xF = wF + yF * " + strideDepth + " - " + padFront + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yF, yR, yC, d2);\n float xValue = getX(b, xF, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return Conv3DDerFilterProgram; + }()); + var Conv3DDerInputProgram = /** @class */ (function () { + function Conv3DDerInputProgram(convInfo) { + this.variableNames = ['dy', 'W']; + this.outputShape = convInfo.inShape; + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var padFront = filterDepth - 1 - convInfo.padInfo.front; + var padTop = filterHeight - 1 - convInfo.padInfo.top; + var padLeft = filterWidth - 1 - convInfo.padInfo.left; + this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d1 = coords.u;\n\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyFCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n float dyF = float(dyFCorner + wF) / " + strideDepth + ".0;\n\n if (dyF < 0.0 || dyF >= " + convInfo.outDepth + ".0 || fract(dyF) > 0.0) {\n continue;\n }\n int idyF = int(dyF);\n\n int wFPerm = " + filterDepth + " - 1 - wF;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n for (int d2 = 0; d2 < " + convInfo.outChannels + "; d2++) {\n float xValue = getDy(batch, idyF, idyR, idyC, d2);\n float wValue = getW(wFPerm, wRPerm, wCPerm, d1, d2);\n dotProd += xValue * wValue;\n }\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return Conv3DDerInputProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DepthwiseConv2DDerFilterProgram = /** @class */ (function () { + function DepthwiseConv2DDerFilterProgram(convInfo) { + this.variableNames = ['x', 'dy']; + this.outputShape = convInfo.filterShape; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var channelMul = convInfo.outChannels / convInfo.inChannels; + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int wR = coords.x;\n int wC = coords.y;\n int d1 = coords.z;\n int dm = coords.w;\n int d2 = d1 * " + channelMul + " + dm;\n\n float dotProd = 0.0;\n\n // TO DO: Vec4 over the batch size\n for (int b = 0; b < " + convInfo.batchSize + "; b++) {\n for (int yR = 0; yR < " + convInfo.outHeight + "; yR++) {\n int xR = wR + yR * " + strideHeight + " - " + padTop + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int yC = 0; yC < " + convInfo.outWidth + "; yC++) {\n int xC = wC + yC * " + strideWidth + " - " + padLeft + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float dyValue = getDy(b, yR, yC, d2);\n float xValue = getX(b, xR, xC, d1);\n dotProd += (xValue * dyValue);\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return DepthwiseConv2DDerFilterProgram; + }()); + var DepthwiseConv2DDerInputProgram = /** @class */ (function () { + function DepthwiseConv2DDerInputProgram(convInfo) { + this.variableNames = ['dy', 'W']; + this.outputShape = convInfo.inShape; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var padTop = filterHeight - 1 - convInfo.padInfo.top; + var padLeft = filterWidth - 1 - convInfo.padInfo.left; + var channelMul = convInfo.outChannels / convInfo.inChannels; + this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d1 = coords[3];\n ivec2 dyCorner = coords.yz - pads;\n int dyRCorner = dyCorner.x;\n int dyCCorner = dyCorner.y;\n\n float dotProd = 0.0;\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n int wRPerm = " + filterHeight + " - 1 - wR;\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n int wCPerm = " + filterWidth + " - 1 - wC;\n\n // TO DO: Vec4 over the channelMul\n for (int dm = 0; dm < " + channelMul + "; dm++) {\n int d2 = d1 * " + channelMul + " + dm;\n float xValue = getDy(batch, idyR, idyC, d2);\n float wValue = getW(wRPerm, wCPerm, d1, dm);\n dotProd += xValue * wValue;\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return DepthwiseConv2DDerInputProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var Conv2DProgram = /** @class */ (function () { + function Conv2DProgram(convInfo, addBias, activation, hasPreluActivationWeights) { + if (addBias === void 0) { addBias = false; } + if (activation === void 0) { activation = null; } + if (hasPreluActivationWeights === void 0) { hasPreluActivationWeights = false; } + this.variableNames = ['x', 'W']; + this.outputShape = convInfo.outShape; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4; + var inputDepthVec4Remainder = convInfo.inChannels % 4; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + var rowDim = isChannelsLast ? 1 : 2; + var colDim = isChannelsLast ? 2 : 3; + var channelDim = isChannelsLast ? 3 : 1; + var activationSnippet = '', applyActivationSnippet = ''; + if (activation) { + if (hasPreluActivationWeights) { + activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; + } + else { + activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n "; + } + applyActivationSnippet = "result = activation(result);"; + } + var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; + if (addBias) { + this.variableNames.push('bias'); + } + if (hasPreluActivationWeights) { + this.variableNames.push('preluActivationWeights'); + } + this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d2 = coords[" + channelDim + "];\n\n ivec2 xRCCorner =\n ivec2(coords[" + rowDim + "], coords[" + colDim + "]) * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, d2) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 wValues = vec4(\n getW(wR, wC, d1, d2),\n getW(wR, wC, d1 + 1, d2),\n getW(wR, wC, d1 + 2, d2),\n getW(wR, wC, d1 + 3, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec4 xValues = vec4(\n getX(batch, xR, xC, d1),\n getX(batch, xR, xC, d1 + 1),\n getX(batch, xR, xC, d1 + 2),\n getX(batch, xR, xC, d1 + 3)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec4 xValues = vec4(\n getX(batch, d1, xR, xC),\n getX(batch, d1 + 1, xR, xC),\n getX(batch, d1 + 2, xR, xC),\n getX(batch, d1 + 3, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n\n if (" + isChannelsLast + ") {\n dotProd +=\n getX(batch, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else {\n dotProd +=\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC) *\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2);\n }\n\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 wValues = vec2(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec2 xValues = vec2(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec2 xValues = vec2(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 wValues = vec3(\n getW(wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n\n if (" + isChannelsLast + ") {\n vec3 xValues = vec3(\n getX(batch, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n dotProd += dot(xValues, wValues);\n } else {\n vec3 xValues = vec3(\n getX(batch, " + inputDepthNearestVec4 + ", xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 1, xR, xC),\n getX(batch, " + inputDepthNearestVec4 + " + 2, xR, xC)\n );\n dotProd += dot(xValues, wValues);\n }\n\n }\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n "; + } + return Conv2DProgram; + }()); + var Conv3DProgram = /** @class */ (function () { + function Conv3DProgram(convInfo) { + this.variableNames = ['x', 'W']; + this.outputShape = convInfo.outShape; + var padFront = convInfo.padInfo.front; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var inputDepthNearestVec4 = Math.floor(convInfo.inChannels / 4) * 4; + var inputDepthVec4Remainder = convInfo.inChannels % 4; + this.userCode = "\n const ivec3 strides = ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int d2 = coords.u;\n\n ivec3 xFRCCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xFCorner = xFRCCorner.x;\n int xRCorner = xFRCCorner.y;\n int xCCorner = xFRCCorner.z;\n\n // Convolve x(?, ?, ?, d1) with w(:, :, :, d1, d2) to get\n // y(yF, yR, yC, d2). ? = to be determined. : = across all\n // values in that axis.\n float dotProd = 0.0;\n for (int wF = 0; wF < " + filterDepth + "; wF++) {\n int xF = xFCorner + wF * " + dilationDepth + ";\n\n if (xF < 0 || xF >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n for (int d1 = 0; d1 < " + inputDepthNearestVec4 + "; d1 += 4) {\n vec4 xValues = vec4(\n getX(batch, xF, xR, xC, d1),\n getX(batch, xF, xR, xC, d1 + 1),\n getX(batch, xF, xR, xC, d1 + 2),\n getX(batch, xF, xR, xC, d1 + 3)\n );\n vec4 wValues = vec4(\n getW(wF, wR, wC, d1, d2),\n getW(wF, wR, wC, d1 + 1, d2),\n getW(wF, wR, wC, d1 + 2, d2),\n getW(wF, wR, wC, d1 + 3, d2)\n );\n\n dotProd += dot(xValues, wValues);\n }\n\n if (" + (inputDepthVec4Remainder === 1) + ") {\n dotProd +=\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + ") *\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2);\n } else if (" + (inputDepthVec4Remainder === 2) + ") {\n vec2 xValues = vec2(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1)\n );\n vec2 wValues = vec2(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2)\n );\n dotProd += dot(xValues, wValues);\n } else if (" + (inputDepthVec4Remainder === 3) + ") {\n vec3 xValues = vec3(\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + "),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 1),\n getX(batch, xF, xR, xC, " + inputDepthNearestVec4 + " + 2)\n );\n vec3 wValues = vec3(\n getW(wF, wR, wC, " + inputDepthNearestVec4 + ", d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 1, d2),\n getW(wF, wR, wC, " + inputDepthNearestVec4 + " + 2, d2)\n );\n dotProd += dot(xValues, wValues);\n }\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return Conv3DProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DepthwiseConv2DProgram = /** @class */ (function () { + function DepthwiseConv2DProgram(convInfo, addBias, activation, hasPreluActivation) { + if (addBias === void 0) { addBias = false; } + if (activation === void 0) { activation = null; } + if (hasPreluActivation === void 0) { hasPreluActivation = false; } + this.variableNames = ['x', 'W']; + this.outputShape = convInfo.outShape; + var xNumRows = convInfo.inHeight; + var xNumCols = convInfo.inWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var channelMul = convInfo.outChannels / convInfo.inChannels; + var activationSnippet = '', applyActivationSnippet = ''; + if (activation) { + if (hasPreluActivation) { + activationSnippet = "float activation(float a) {\n float b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; + } + else { + activationSnippet = "\n float activation(float x) {\n " + activation + "\n }\n "; + } + applyActivationSnippet = "result = activation(result);"; + } + var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; + if (addBias) { + this.variableNames.push('bias'); + } + if (hasPreluActivation) { + this.variableNames.push('preluActivationWeights'); + } + this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2 / " + channelMul + ";\n int q = d2 - d1 * " + channelMul + ";\n\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // Convolve x(?, ?, d1) with w(:, :, d1, q) to get y(yR, yC, d2).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n // TO DO(dsmilkov): Flatten the two for loops and vec4 the operations.\n for (int wR = 0; wR < " + filterHeight + "; wR++) {\n int xR = xRCorner + wR * " + dilationHeight + ";\n\n if (xR < 0 || xR >= " + xNumRows + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidth + "; wC++) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n if (xC < 0 || xC >= " + xNumCols + ") {\n continue;\n }\n\n float xVal = getX(batch, xR, xC, d1);\n float wVal = getW(wR, wC, d1, q);\n dotProd += xVal * wVal;\n }\n }\n\n float result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n "; + } + return DepthwiseConv2DProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DepthwiseConvPacked2DProgram = /** @class */ (function () { + function DepthwiseConvPacked2DProgram(convInfo, addBias, activation, hasPreluActivation) { + if (addBias === void 0) { addBias = false; } + if (activation === void 0) { activation = null; } + if (hasPreluActivation === void 0) { hasPreluActivation = false; } + this.variableNames = ['x', 'W']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = convInfo.outShape; + var xNumRows = convInfo.inHeight; + var xNumCols = convInfo.inWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var texelsAcross = filterWidth; + var mainLoop = "int xR; int xC; int xCOffset;"; + for (var r = 0; r < filterHeight; r++) { + for (var c = 0; c < filterWidth; c++) { + mainLoop += "\n vec4 xTexelR" + r + "C" + c * 2 + " = vec4(0.);\n vec4 wR" + r + "C" + c + " = vec4(0.);\n vec4 xR" + r + "C" + c + " = vec4(0.);"; + } + } + /** + * This vectorized implementation works by gathering the values needed for + * each output channel's dot product into vec4's and then multiplying them + * all together (this happens in the final double for-loop below). Most of + * the main loop consists of constructing these vec4's with the minimum + * number of texture2D calls, which means making use of all four returned + * values from a texture2D call at once. + */ + for (var r = 0; r < filterHeight; r++) { + for (var texelC = 0; texelC < texelsAcross; texelC++) { + var c = texelC * 2; + mainLoop += "\n xR = xRCorner + " + r * dilationHeight + ";\n xC = xCCorner + " + c * dilationWidth + ";\n "; + if (strideWidth === 1) { + if (c < filterWidth) { + // If padding is odd, the outer texels have to be composed. + if (padLeft % 2 === 1) { + // TODO: Ensure vec4 previous does not result in redundant sample, + // and avoid setting xTexelRC's that exceed the boundary in the + // first place rather than resetting them to vec4(0)). + // To compute xCOffset: + // - If padding is odd, we must add 1 to ensure we ask for an + // even-numbered row. + // - We subtract 2 to access the previous texel. + mainLoop += "\n xCOffset = xC + 1;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + 1 - 2;\n if(xR >= 0 && xR < " + xNumRows + " && xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n vec4 previous = getX(batch, xR, xCOffset, d1);\n xR" + r + "C" + c + " = vec4(previous.zw, xTexelR" + r + "C" + c + ".xy);\n } else {\n xR" + r + "C" + c + " = vec4(0, 0, xTexelR" + r + "C" + c + ".xy);\n }\n "; + } + else { + // Padding is even, so xRC corresponds to a single texel. + mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + " && xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = xTexelR" + r + "C" + c + ";\n "; + } + if (c + 1 < filterWidth) { + // If dilation is even, the second entry should match the first + // (either both are composed or both are single samples). But if + // dilation is odd, then the second entry should be the opposite + // of the first (if the first is composed, the second is a single + // sample, and vice versa.) + var nextTexelOffset = padLeft % 2 === 0 ? + nearestLargerEven(dilationWidth) : + dilationWidth; + if ((dilationWidth % 2 === 0 && padLeft % 2 === 1) || + (dilationWidth % 2 !== 0 && padLeft % 2 !== 1)) { + mainLoop += "\n xCOffset = xC + " + padLeft % 2 + " + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n "; + // If dilation > 1 then the xRC's will not be able to share any + // values, so each xRC will require two unique calls to getX. + if (dilationWidth > 1) { + mainLoop += "\n xCOffset -= 2;\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n "; + } + mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".xy);\n "; + } + else { + mainLoop += "\n xCOffset = xC + " + nextTexelOffset + ";\n\n if(xR >= 0 && xR < " + xNumRows + " &&\n xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n }\n\n xR" + r + "C" + (c + 1) + " = xTexelR" + r + "C" + (c + 2) + ";\n "; + } + } + } + } + else { // stride > 1 + if (c < filterWidth) { + mainLoop += "\n if(xR >= 0 && xR < " + xNumRows + ") {\n "; + // Depending on whether padLeft is even or odd, we want either the + // xy or zw channels from X texels for xR${r}C${c}. If padLeft is + // even, xR${r}C${c + 1} is simply the zw channels of texels we've + // already sampled. But if padLeft is odd, xR${r}C{$c + 1}.zw will + // need to come from the xy channels of a new texel, hence the `vec4 + // final` initialized below. + if (padLeft % 2 === 1) { + mainLoop += "\n xCOffset = xC + 1 - " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n if(xC + 1 >= 0 && xC + 1 < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xC + 1, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n "; + if (c + 1 < filterWidth) { + mainLoop += "\n vec4 final = vec4(0.);\n xCOffset = xC + 1 + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n final = getX(batch, xR, xCOffset, d1);\n }\n xR" + r + "C" + (c + 1) + " = vec4(xTexelR" + r + "C" + (c + 2) + ".xy, final.xy);\n "; + } + } + else { + mainLoop += "\n if(xC >= 0 && xC < " + xNumCols + ") {\n xTexelR" + r + "C" + c + " = getX(batch, xR, xC, d1);\n } else {\n xTexelR" + r + "C" + c + " = vec4(0.);\n }\n\n xCOffset = xC + " + strideWidth + ";\n if(xCOffset >= 0 && xCOffset < " + xNumCols + ") {\n xTexelR" + r + "C" + (c + 2) + " = getX(batch, xR, xCOffset, d1);\n } else {\n xTexelR" + r + "C" + (c + 2) + " = vec4(0.);\n }\n\n xR" + r + "C" + c + " = vec4(\n xTexelR" + r + "C" + c + ".xy, xTexelR" + r + "C" + (c + 2) + ".xy);\n "; + if (c + 1 < filterWidth) { + mainLoop += "\n xR" + r + "C" + (c + 1) + " = vec4(\n xTexelR" + r + "C" + c + ".zw, xTexelR" + r + "C" + (c + 2) + ".zw);\n "; + } + } + mainLoop += "}"; + } + } + if (c < filterWidth) { + mainLoop += "\n vec4 wTexelR" + r + "C" + c + " = getW(" + r + ", " + c + ", d1, q);\n wR" + r + "C" + c + " = vec4(wTexelR" + r + "C" + c + ".xz, wTexelR" + r + "C" + c + ".xz);\n "; + if (c + 1 < filterWidth) { + mainLoop += "\n vec4 wTexelR" + r + "C" + (c + 1) + " = getW(" + r + ", " + (c + 1) + ", d1, q);\n wR" + r + "C" + (c + 1) + " =\n vec4(wTexelR" + r + "C" + (c + 1) + ".xz, wTexelR" + r + "C" + (c + 1) + ".xz);"; + } + } + } + } + for (var r = 0; r < filterHeight; r++) { + for (var c = 0; c < filterWidth; c++) { + mainLoop += "dotProd += xR" + r + "C" + c + " * wR" + r + "C" + c + ";"; + } + } + var activationSnippet = '', applyActivationSnippet = ''; + if (activation) { + if (hasPreluActivation) { + activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; + } + else { + activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }"; + } + applyActivationSnippet = "result = activation(result);"; + } + var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; + if (addBias) { + this.variableNames.push('bias'); + } + if (hasPreluActivation) { + this.variableNames.push('preluActivationWeights'); + } + this.userCode = "\n " + activationSnippet + "\n\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n\n ivec4 coords = getOutputCoords();\n int batch = coords.x;\n ivec2 xRCCorner = coords.yz * strides - pads;\n int d2 = coords.w;\n int d1 = d2;\n int q = 0;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n vec4 dotProd = vec4(0.);\n\n " + mainLoop + "\n\n vec4 result = dotProd;\n " + addBiasSnippet + "\n " + applyActivationSnippet + "\n setOutput(result);\n }\n "; + } + return DepthwiseConvPacked2DProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var CropAndResizeProgram = /** @class */ (function () { + function CropAndResizeProgram(imageShape, boxShape, cropSize, method, extrapolationValue) { + this.variableNames = ['Image', 'Boxes', 'BoxInd']; + this.outputShape = []; + var batch = imageShape[0], imageHeight = imageShape[1], imageWidth = imageShape[2], depth = imageShape[3]; + var numBoxes = boxShape[0]; + var cropHeight = cropSize[0], cropWidth = cropSize[1]; + this.outputShape = [numBoxes, cropHeight, cropWidth, depth]; + var methodId = method === 'bilinear' ? 1 : 0; + var _a = [imageHeight - 1 + ".0", imageWidth - 1 + ".0"], inputHeightFloat = _a[0], inputWidthFloat = _a[1]; + var _b = cropHeight > 1 ? + [ + "" + (imageHeight - 1) / (cropHeight - 1), + '(y2-y1) * height_ratio', + "y1*" + inputHeightFloat + " + float(y)*(height_scale)", + ] : + [ + '0.0', + '0.0', + "0.5 * (y1+y2) * " + inputHeightFloat, + ], heightRatio = _b[0], heightScale = _b[1], inY = _b[2]; + var _c = cropWidth > 1 ? + [ + "" + (imageWidth - 1) / (cropWidth - 1), + '(x2-x1) * width_ratio', + "x1*" + inputWidthFloat + " + float(x)*(width_scale)", + ] : + [ + '0.0', + '0.0', + "0.5 * (x1+x2) * " + inputWidthFloat, + ], widthRatio = _c[0], widthScale = _c[1], inX = _c[2]; + // Reference implementation + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op_gpu.cu.cc + this.userCode = "\n const float height_ratio = float(" + heightRatio + ");\n const float width_ratio = float(" + widthRatio + ");\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int y = coords[1];\n int x = coords[2];\n int d = coords[3];\n\n // get box vals\n float y1 = getBoxes(b,0);\n float x1 = getBoxes(b,1);\n float y2 = getBoxes(b,2);\n float x2 = getBoxes(b,3);\n\n // get image in batch index\n int bInd = round(getBoxInd(b));\n if(bInd < 0 || bInd >= " + batch + ") {\n return;\n }\n\n float height_scale = " + heightScale + ";\n float width_scale = " + widthScale + ";\n\n float in_y = " + inY + ";\n if( in_y < 0.0 || in_y > " + inputHeightFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n float in_x = " + inX + ";\n if( in_x < 0.0 || in_x > " + inputWidthFloat + " ) {\n setOutput(float(" + extrapolationValue + "));\n return;\n }\n\n vec2 sourceFracIndexCR = vec2(in_x,in_y);\n if(" + methodId + " == 1) {\n // Compute the four integer indices.\n ivec2 sourceFloorCR = ivec2(sourceFracIndexCR);\n ivec2 sourceCeilCR = ivec2(ceil(sourceFracIndexCR));\n\n float topLeft = getImage(b, sourceFloorCR.y, sourceFloorCR.x, d);\n float bottomLeft = getImage(b, sourceCeilCR.y, sourceFloorCR.x, d);\n float topRight = getImage(b, sourceFloorCR.y, sourceCeilCR.x, d);\n float bottomRight = getImage(b, sourceCeilCR.y, sourceCeilCR.x, d);\n\n vec2 fracCR = sourceFracIndexCR - vec2(sourceFloorCR);\n\n float top = topLeft + (topRight - topLeft) * fracCR.x;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracCR.x;\n float newValue = top + (bottom - top) * fracCR.y;\n setOutput(newValue);\n } else {\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestCR = ivec2(floor(\n sourceFracIndexCR + vec2(0.5,0.5)));\n float newValue = getImage(b, sourceNearestCR.y, sourceNearestCR.x, d);\n setOutput(newValue);\n }\n }\n "; + } + return CropAndResizeProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var CumSumProgram = /** @class */ (function () { + function CumSumProgram(shape, exclusive, reverse) { + this.variableNames = ['x']; + this.outputShape = shape; + var rank = shape.length; + var finalDim = shape[shape.length - 1]; + var comparator = reverse ? '<' : '>'; + this.userCode = "\n int getIndex(int i) {\n " + (reverse ? "return " + finalDim + " -i - 1;" : 'return i;') + "\n }\n\n void main() {\n " + getCoordsDataType(rank) + " coords = getOutputCoords();\n int end = " + getFinalCoord(rank, 'coords') + ";\n float val = 0.0;\n for (int i = " + finalDim + " - 1; i >= 0; i -= 1) {\n int idx = getIndex(i);\n if (idx " + comparator + " end) {\n continue;\n }\n if (idx == end && " + exclusive + ") {\n continue;\n }\n " + getFinalCoord(rank, 'coords') + " = idx;\n val += getX(" + getCoords(rank, 'coords') + ");\n }\n setOutput(val);\n }\n "; + } + return CumSumProgram; + }()); + function getCoords(rank, name) { + if (rank === 1) { + return "" + name; + } + else if (rank === 2) { + return name + ".x, " + name + ".y"; + } + else if (rank === 3) { + return name + ".x, " + name + ".y, " + name + ".z"; + } + else if (rank === 4) { + return name + ".x, " + name + ".y, " + name + ".z, " + name + ".w"; + } + else { + throw Error("Cumulative sum for rank " + rank + " is not yet supported"); + } + } + function getFinalCoord(rank, name) { + if (rank === 1) { + return "" + name; + } + else if (rank === 2) { + return name + ".y"; + } + else if (rank === 3) { + return name + ".z"; + } + else if (rank === 4) { + return name + ".w"; + } + else { + throw Error("Cumulative sum for rank " + rank + " is not yet supported"); + } + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DecodeMatrixProgram = /** @class */ (function () { + function DecodeMatrixProgram(outputShape) { + this.variableNames = ['A']; + this.packedInputs = false; + this.packedOutput = true; + this.outPackingScheme = PackingScheme.DENSE; + var texShape = getDenseTexShape(outputShape); + var glsl = getGlslDifferences(); + this.outputShape = outputShape; + this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = 4 * (resTexRC.x * " + texShape[1] + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getA(rc.x, rc.y, rc.z);\n }\n\n " + glsl.output + " = result;\n }\n "; + } + return DecodeMatrixProgram; + }()); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DecodeMatrixPackedProgram = /** @class */ (function () { + function DecodeMatrixPackedProgram(outputShape) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + this.outPackingScheme = PackingScheme.DENSE; + var texShape = getDenseTexShape(outputShape); + var glsl = getGlslDifferences(); + this.outputShape = outputShape; + this.userCode = "\n ivec3 outCoordsFromFlatIndex(int index) {\n " + getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], outputShape) + "\n return ivec3(r, c, d);\n }\n\n void main() {\n ivec2 resTexRC = ivec2(resultUV.yx *\n vec2(" + texShape[0] + ", " + texShape[1] + "));\n int index = 4 * (resTexRC.x * " + texShape[1] + " + resTexRC.y);\n\n vec4 result = vec4(0.);\n\n for (int i=0; i<4; i++) {\n int flatIndex = index + i;\n ivec3 rc = outCoordsFromFlatIndex(flatIndex);\n result[i] = getChannel(getA(rc.x, rc.y, rc.z), vec2(rc.y, rc.z));\n }\n\n " + glsl.output + " = result;\n }\n "; + } + return DecodeMatrixPackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DepthToSpaceProgram = /** @class */ (function () { + function DepthToSpaceProgram(outputShape, blockSize, dataFormat) { + this.variableNames = ['x']; + this.outputShape = []; + this.outputShape = outputShape; + this.blockSize = blockSize; + this.dataFormat = dataFormat; + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int h = " + this.getHeightCoordString() + ";\n int w = " + this.getWidthCoordString() + ";\n int d = " + this.getDepthCoordString() + ";\n\n int in_h = h / " + blockSize + ";\n int offset_h = imod(h, " + blockSize + ");\n int in_w = w / " + blockSize + ";\n int offset_w = imod(w, " + blockSize + ");\n int offset_d = (offset_h * " + blockSize + " + offset_w) *\n " + this.getOutputDepthSize() + ";\n int in_d = d + offset_d;\n\n float result = " + this.getInputSamplingString() + ";\n setOutput(result);\n }\n "; + } + DepthToSpaceProgram.prototype.getHeightCoordString = function () { + if (this.dataFormat === 'NHWC') { + return "coords[1]"; + } + else { + return "coords[2]"; + } + }; + DepthToSpaceProgram.prototype.getWidthCoordString = function () { + if (this.dataFormat === 'NHWC') { + return "coords[2]"; + } + else { + return "coords[3]"; + } + }; + DepthToSpaceProgram.prototype.getDepthCoordString = function () { + if (this.dataFormat === 'NHWC') { + return "coords[3]"; + } + else { + return "coords[1]"; + } + }; + DepthToSpaceProgram.prototype.getOutputDepthSize = function () { + if (this.dataFormat === 'NHWC') { + return this.outputShape[3]; + } + else { + return this.outputShape[1]; + } + }; + DepthToSpaceProgram.prototype.getInputSamplingString = function () { + if (this.dataFormat === 'NHWC') { + return "getX(b, in_h, in_w, in_d)"; + } + else { + return "getX(b, in_d, in_h, in_w)"; + } + }; + return DepthToSpaceProgram; + }()); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DiagProgram = /** @class */ (function () { + function DiagProgram(size) { + this.variableNames = ['X']; + this.outputShape = [size, size]; + this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n float val = coords[0] == coords[1] ? getX(coords[0]) : 0.0;\n setOutput(val);\n }\n "; + } + return DiagProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var EncodeFloatProgram = /** @class */ (function () { + function EncodeFloatProgram(outputShape) { + this.variableNames = ['A']; + this.outTexUsage = TextureUsage.DOWNLOAD; + var glsl = getGlslDifferences(); + this.outputShape = outputShape; + this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n float x = getAAtOutCoords();\n " + glsl.output + " = encode_float(x);\n }\n "; + } + return EncodeFloatProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var EncodeFloatPackedProgram = /** @class */ (function () { + function EncodeFloatPackedProgram(outputShape) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = false; + this.outTexUsage = TextureUsage.DOWNLOAD; + var glsl = getGlslDifferences(); + this.outputShape = outputShape; + this.userCode = "\n " + ENCODE_FLOAT_SNIPPET + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n float x = getChannel(getAAtOutCoords(), vec2(coords.y, coords.z));\n " + glsl.output + " = encode_float(x);\n }\n "; + } + return EncodeFloatPackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var EncodeMatrixProgram = /** @class */ (function () { + function EncodeMatrixProgram(outputShape, texShape, inputIsUnsignedByte) { + if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; } + this.variableNames = ['A']; + var glsl = getGlslDifferences(); + var height = texShape[0], width = texShape[1]; + this.outputShape = outputShape; + var output = "result"; + if (inputIsUnsignedByte) { + output = "floor(result * 255. + 0.5)"; + } + this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n int flatIndex = getFlatIndex(coords);\n int offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n \n int r = flatIndex / " + width + ";\n int c = imod(flatIndex, " + width + ");\n vec2 uv = (vec2(c, r) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n\n float result;\n\n if(offset == 0) {\n result = values[0];\n } else if(offset == 1) {\n result = values[1];\n } else if(offset == 2) {\n result = values[2];\n } else {\n result = values[3];\n }\n\n " + glsl.output + " = vec4(" + output + ", 0., 0., 0.);\n }\n "; + } + return EncodeMatrixProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /* + This is how the shader encodes a tensor with shape = [2, 3, 5] + (indices are [batch, row, col]). + + 000|001 002|003 004|xxx 020|021 022|023 024|xxx + ------- ------- ------- ------- ------- ------- + 010|011 012|013 014|xxx xxx|xxx xxx|xxx xxx|xxx + + 100|101 102|103 104|xxx 120|121 122|123 124|xxx + ------- ------- ------- ------- ------- ------- + 110|111 112|113 114|xxx xxx|xxx xxx|xxx xxx|xxx + + Single texels contain only values from the same batch, and from adjacent rows + and columns. + */ + var EncodeMatrixPackedProgram = /** @class */ (function () { + function EncodeMatrixPackedProgram(outputShape, texShape, inputIsUnsignedByte) { + if (inputIsUnsignedByte === void 0) { inputIsUnsignedByte = false; } + this.variableNames = ['A']; + this.packedInputs = false; + this.packedOutput = true; + var glsl = getGlslDifferences(); + var height = texShape[0], width = texShape[1]; + this.outputShape = outputShape; + var mainLoop = ''; + var output = 'result'; + if (inputIsUnsignedByte) { + output = 'floor(result * 255. + 0.5)'; + } + for (var row = 0; row <= 1; row++) { + for (var col = 0; col <= 1; col++) { + var channel = row * 2 + col; + mainLoop += "\n localCoords = coords;\n if(localCoords[2] + " + col + " < " + outputShape[2] + ") {\n localCoords[2] += " + col + ";\n if(localCoords[1] + " + row + " < " + outputShape[1] + ") {\n localCoords[1] += " + row + ";\n\n flatIndex = getFlatIndex(localCoords);\n offset = imod(flatIndex, 4);\n\n flatIndex = idiv(flatIndex, 4, 1.);\n\n r = flatIndex / " + width + ";\n c = imod(flatIndex, " + width + ");\n uv = (vec2(c, r) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n values = " + glsl.texture2D + "(A, uv);\n\n if(offset == 0) {\n result[" + channel + "] = values[0];\n } else if(offset == 1) {\n result[" + channel + "] = values[1];\n } else if(offset == 2) {\n result[" + channel + "] = values[2];\n } else {\n result[" + channel + "] = values[3];\n }\n }\n }\n "; + } + } + this.userCode = "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 coords = getOutputCoords();\n\n vec4 result = vec4(0.);\n int flatIndex, r, c, offset;\n ivec3 localCoords;\n vec2 uv;\n vec4 values;\n\n " + mainLoop + "\n\n " + glsl.output + " = " + output + ";\n }\n "; + } + return EncodeMatrixPackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var COMPLEX_FFT = { + REAL: 'return real * expR - imag * expI;', + IMAG: 'return real * expI + imag * expR;' + }; + var FFTProgram = /** @class */ (function () { + function FFTProgram(op, inputShape, inverse) { + this.variableNames = ['real', 'imag']; + var innerDim = inputShape[1]; + this.outputShape = inputShape; + var exponentMultiplierSnippet = inverse ? "2.0 * " + Math.PI : "-2.0 * " + Math.PI; + var resultDenominator = inverse ? innerDim + ".0" : '1.0'; + this.userCode = "\n const float exponentMultiplier = " + exponentMultiplierSnippet + ";\n\n float unaryOpComplex(float real, float expR, float imag, float expI) {\n " + op + "\n }\n\n float mulMatDFT(int batch, int index) {\n float indexRatio = float(index) / float(" + innerDim + ");\n float exponentMultiplierTimesIndexRatio =\n exponentMultiplier * indexRatio;\n\n float result = 0.0;\n\n for (int i = 0; i < " + innerDim + "; i++) {\n // x = (-2|2 * PI / N) * index * i;\n float x = exponentMultiplierTimesIndexRatio * float(i);\n float expR = cos(x);\n float expI = sin(x);\n float real = getReal(batch, i);\n float imag = getImag(batch, i);\n\n result +=\n unaryOpComplex(real, expR, imag, expI) / " + resultDenominator + ";\n }\n\n return result;\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n setOutput(mulMatDFT(coords[0], coords[1]));\n }\n "; + } + return FFTProgram; + }()); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var FillProgram = /** @class */ (function () { + function FillProgram(shape, value) { + this.outputShape = []; + this.variableNames = ['x']; + this.outputShape = shape; + this.userCode = "\n uniform float value;\n void main() {\n // Input can be obtained from uniform value.\n setOutput(value);\n }\n "; + } + FillProgram.prototype.getCustomSetupFunc = function (value) { + var _this = this; + return function (gpgpu, webGLProgram) { + if (_this.valueLoc == null) { + _this.valueLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'value'); + } + gpgpu.gl.uniform1f(_this.valueLoc, value); + }; + }; + return FillProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var FromPixelsProgram = /** @class */ (function () { + function FromPixelsProgram(outputShape) { + this.variableNames = ['A']; + var glsl = getGlslDifferences(); + var height = outputShape[0], width = outputShape[1]; + this.outputShape = outputShape; + this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n vec2 uv = (vec2(texC, texR) + halfCR) / vec2(" + width + ".0, " + height + ".0);\n\n vec4 values = " + glsl.texture2D + "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n setOutput(floor(value * 255.0 + 0.5));\n }\n "; + } + return FromPixelsProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var FromPixelsPackedProgram = /** @class */ (function () { + function FromPixelsPackedProgram(outputShape) { + this.variableNames = ['A']; + this.packedInputs = false; + this.packedOutput = true; + var glsl = getGlslDifferences(); + var height = outputShape[0], width = outputShape[1]; + this.outputShape = outputShape; + this.userCode = "\n void main() {\n ivec3 coords = getOutputCoords();\n int texR = coords[0];\n int texC = coords[1];\n int depth = coords[2];\n\n vec4 result = vec4(0.);\n\n for(int row=0; row<=1; row++) {\n for(int col=0; col<=1; col++) {\n texC = coords[1] + row;\n depth = coords[2] + col;\n\n vec2 uv = (vec2(texC, texR) + halfCR) /\n vec2(" + width + ".0, " + height + ".0);\n vec4 values = " + glsl.texture2D + "(A, uv);\n float value;\n if (depth == 0) {\n value = values.r;\n } else if (depth == 1) {\n value = values.g;\n } else if (depth == 2) {\n value = values.b;\n } else if (depth == 3) {\n value = values.a;\n }\n\n result[row * 2 + col] = floor(value * 255.0 + 0.5);\n }\n }\n\n " + glsl.output + " = result;\n }\n "; + } + return FromPixelsPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var GatherProgram = /** @class */ (function () { + function GatherProgram(aShape, indicesLength, axis) { + this.variableNames = ['A', 'indices']; + var outputShape = aShape.slice(); + outputShape[axis] = indicesLength; + this.outputShape = outputShape; + this.rank = outputShape.length; + var dtype = getCoordsDataType(this.rank); + var sourceCoords = getSourceCoords$1(aShape, axis); + this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n "; + } + return GatherProgram; + }()); + function getSourceCoords$1(aShape, axis) { + var rank = aShape.length; + if (rank > 4) { + throw Error("Gather for rank " + rank + " is not yet supported"); + } + if (rank === 1) { + return "int(getIndices(resRC))"; + } + var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; + var sourceCoords = []; + for (var i = 0; i < aShape.length; i++) { + if (i === axis) { + sourceCoords.push("int(getIndices(" + currentCoords[i] + "))"); + } + else { + sourceCoords.push("" + currentCoords[i]); + } + } + return sourceCoords.join(); + } + + var GatherNDProgram = /** @class */ (function () { + function GatherNDProgram(sliceDim, strides, shape) { + this.sliceDim = sliceDim; + this.strides = strides; + this.variableNames = ['x', 'indices']; + this.outputShape = shape; + var stridesType = getCoordsDataType(strides.length); + var dtype = getCoordsDataType(shape.length); + var strideString = this.sliceDim > 1 ? 'strides[j]' : 'strides'; + this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + this.strides + ");\n void main() {\n " + dtype + " coords = getOutputCoords();\n int flattenIndex = 0;\n for (int j = 0; j < " + this.sliceDim + "; j++) {\n int index = round(getIndices(coords[0], j));\n flattenIndex += index * " + strideString + ";\n }\n setOutput(getX(flattenIndex, coords[1]));\n }\n "; + } + return GatherNDProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function createVertexShader$1(gl, debug) { + var glsl = getGlslDifferences(); + var vertexShaderSource = glsl.version + "\n precision highp float;\n " + glsl.attribute + " vec3 clipSpacePos;\n " + glsl.attribute + " vec2 uv;\n " + glsl.varyingVs + " vec2 resultUV;\n\n void main() {\n gl_Position = vec4(clipSpacePos, 1);\n resultUV = uv;\n }"; + return createVertexShader(gl, debug, vertexShaderSource); + } + function createVertexBuffer(gl, debug) { + // [x y z u v] * [upper-left, lower-left, upper-right, lower-right] + var vertexArray = new Float32Array([-1, 1, 0, 0, 1, -1, -1, 0, 0, 0, 1, 1, 0, 1, 1, 1, -1, 0, 1, 0]); + return createStaticVertexBuffer(gl, debug, vertexArray); + } + function createIndexBuffer(gl, debug) { + // OpenGL (and WebGL) have "CCW == front" winding + var triangleVertexIndices = new Uint16Array([0, 1, 2, 2, 1, 3]); + return createStaticIndexBuffer(gl, debug, triangleVertexIndices); + } + function createAndConfigureTexture(gl, debug, width, height, internalFormat, textureFormat, textureType) { + validateTextureSize(width, height); + var texture = createTexture(gl, debug); + var tex2d = gl.TEXTURE_2D; + callAndCheck(gl, debug, function () { return gl.bindTexture(tex2d, texture); }); + callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE); }); + callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE); }); + callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MIN_FILTER, gl.NEAREST); }); + callAndCheck(gl, debug, function () { return gl.texParameteri(tex2d, gl.TEXTURE_MAG_FILTER, gl.NEAREST); }); + callAndCheck(gl, debug, function () { return gl.texImage2D(tex2d, 0, internalFormat, width, height, 0, textureFormat, textureType, null); }); + callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); + return texture; + } + function createFloat32MatrixTexture(gl, debug, rows, columns, textureConfig) { + var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; + return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatFloat, textureConfig.textureFormatFloat, gl.FLOAT); + } + function createFloat16MatrixTexture(gl, debug, rows, columns, textureConfig) { + var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; + return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatHalfFloat, textureConfig.textureFormatFloat, textureConfig.textureTypeHalfFloat); + } + function createUnsignedBytesMatrixTexture(gl, debug, rows, columns, textureConfig) { + var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; + return createAndConfigureTexture(gl, debug, width, height, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE); + } + function createPackedMatrixTexture(gl, debug, rows, columns, textureConfig) { + var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; + return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatPackedFloat, gl.RGBA, gl.FLOAT); + } + function createFloat16PackedMatrixTexture(gl, debug, rows, columns, textureConfig) { + var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; + return createAndConfigureTexture(gl, debug, width, height, textureConfig.internalFormatPackedHalfFloat, gl.RGBA, textureConfig.textureTypeHalfFloat); + } + function bindVertexProgramAttributeStreams(gl, debug, program, vertexBuffer) { + var posOffset = 0; // x is the first buffer element + var uvOffset = 3 * 4; // uv comes after [x y z] + var stride = (3 * 4) + (2 * 4); // xyz + uv, each entry is 4-byte float. + callAndCheck(gl, debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, vertexBuffer); }); + var success = bindVertexBufferToProgramAttribute(gl, debug, program, 'clipSpacePos', vertexBuffer, 3, stride, posOffset); + return success && + bindVertexBufferToProgramAttribute(gl, debug, program, 'uv', vertexBuffer, 2, stride, uvOffset); + } + function uploadDenseMatrixToTexture(gl, debug, texture, width, height, data, textureConfig) { + callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); + var dataForUpload, texelDataType, internalFormat; + if (data instanceof Uint8Array) { + dataForUpload = new Uint8Array(width * height * 4); + texelDataType = gl.UNSIGNED_BYTE; + internalFormat = gl.RGBA; + } + else { + dataForUpload = new Float32Array(width * height * 4); + texelDataType = gl.FLOAT; + internalFormat = textureConfig.internalFormatPackedFloat; + } + dataForUpload.set(data); + callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, internalFormat, width, height, 0, gl.RGBA, texelDataType, dataForUpload); }); + callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); + } + function uploadPixelDataToTexture(gl, debug, texture, pixels) { + callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, texture); }); + if (pixels.data instanceof Uint8Array) { + callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, pixels.width, pixels.height, 0, gl.RGBA, gl.UNSIGNED_BYTE, pixels.data); }); + } + else { + callAndCheck(gl, debug, function () { return gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA, gl.RGBA, gl.UNSIGNED_BYTE, pixels); }); + } + callAndCheck(gl, debug, function () { return gl.bindTexture(gl.TEXTURE_2D, null); }); + } + function createBufferFromOutputTexture(gl2, debug, rows, columns, textureConfig) { + // Create and bind the buffer. + var buffer = gl2.createBuffer(); + callAndCheck(gl2, debug, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); }); + // Initialize the buffer to the size of the texture in bytes. + var bytesPerFloat = 4; + var valuesPerTexel = 4; + var bufferSizeBytes = bytesPerFloat * valuesPerTexel * rows * columns; + callAndCheck(gl2, debug, function () { return gl2.bufferData(gl2.PIXEL_PACK_BUFFER, bufferSizeBytes, gl2.STREAM_READ); }); + // Enqueue a command on the GPU command queue to copy of texture into the + // buffer. + callAndCheck(gl2, debug, function () { return gl2.readPixels(0, 0, columns, rows, gl2.RGBA, gl2.FLOAT, 0); }); + callAndCheck(gl2, debug, function () { return gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); }); + return buffer; + } + function downloadFloat32MatrixFromBuffer(gl, buffer, size) { + var gl2 = gl; + var downloadTarget = new Float32Array(size); + gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); + gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); + gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); + return downloadTarget; + } + function downloadByteEncodedFloatMatrixFromOutputTexture(gl, debug, rows, columns, textureConfig) { + var _a = getUnpackedMatrixTextureShapeWidthHeight(rows, columns), w = _a[0], h = _a[1]; + var numChannels = 4; + var downloadTarget = new Uint8Array(getUnpackedArraySizeFromMatrixSize(rows * columns, numChannels)); + callAndCheck(gl, debug, function () { return gl.readPixels(0, 0, w, h, textureConfig.downloadTextureFormat, gl.UNSIGNED_BYTE, downloadTarget); }); + // By wrapping the buffer in a Float32Array, we use native browser IEEE 754 + // decoding of the 4 bytes that back each 32 bit float. + return new Float32Array(downloadTarget.buffer); + } + function downloadPackedMatrixFromBuffer(gl, buffer, batch, rows, cols, physicalRows, physicalCols, textureConfig) { + var gl2 = gl; + var downloadTarget = new Float32Array(getPackedRGBAArraySizeFromMatrixShape(physicalRows, physicalCols)); + gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, buffer); + gl2.getBufferSubData(gl2.PIXEL_PACK_BUFFER, 0, downloadTarget); + gl2.bindBuffer(gl2.PIXEL_PACK_BUFFER, null); + return downloadTarget; + } + function downloadMatrixFromPackedOutputTexture(gl, debug, physicalRows, physicalCols) { + var packedRGBA = new Float32Array(physicalRows * physicalCols * 4); + callAndCheck(gl, debug, function () { return gl.readPixels(0, 0, physicalCols, physicalRows, gl.RGBA, gl.FLOAT, packedRGBA); }); + return packedRGBA; + } + + var gpgpu_util = /*#__PURE__*/Object.freeze({ + createVertexShader: createVertexShader$1, + createVertexBuffer: createVertexBuffer, + createIndexBuffer: createIndexBuffer, + createFloat32MatrixTexture: createFloat32MatrixTexture, + createFloat16MatrixTexture: createFloat16MatrixTexture, + createUnsignedBytesMatrixTexture: createUnsignedBytesMatrixTexture, + createPackedMatrixTexture: createPackedMatrixTexture, + createFloat16PackedMatrixTexture: createFloat16PackedMatrixTexture, + bindVertexProgramAttributeStreams: bindVertexProgramAttributeStreams, + uploadDenseMatrixToTexture: uploadDenseMatrixToTexture, + uploadPixelDataToTexture: uploadPixelDataToTexture, + createBufferFromOutputTexture: createBufferFromOutputTexture, + downloadFloat32MatrixFromBuffer: downloadFloat32MatrixFromBuffer, + downloadByteEncodedFloatMatrixFromOutputTexture: downloadByteEncodedFloatMatrixFromOutputTexture, + downloadPackedMatrixFromBuffer: downloadPackedMatrixFromBuffer, + downloadMatrixFromPackedOutputTexture: downloadMatrixFromPackedOutputTexture + }); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var GPGPUContext = /** @class */ (function () { + function GPGPUContext(gl) { + this.outputTexture = null; + this.program = null; + this.disposed = false; + this.vertexAttrsAreBound = false; + this.itemsToPoll = []; + var glVersion = env().getNumber('WEBGL_VERSION'); + if (gl != null) { + this.gl = gl; + setWebGLContext(glVersion, gl); + } + else { + this.gl = getWebGLContext(glVersion); + } + // WebGL 2.0 enables texture floats without an extension. + var COLOR_BUFFER_FLOAT = 'WEBGL_color_buffer_float'; + var COLOR_BUFFER_HALF_FLOAT = 'EXT_color_buffer_half_float'; + if (env().getNumber('WEBGL_VERSION') === 1) { + var TEXTURE_FLOAT = 'OES_texture_float'; + var TEXTURE_HALF_FLOAT = 'OES_texture_half_float'; + this.textureFloatExtension = + getExtensionOrThrow(this.gl, this.debug, TEXTURE_FLOAT); + if (hasExtension(this.gl, TEXTURE_HALF_FLOAT)) { + this.textureHalfFloatExtension = getExtensionOrThrow(this.gl, this.debug, TEXTURE_HALF_FLOAT); + } + else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { + throw new Error('GL context does not support half float textures, yet the ' + + 'environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); + } + this.colorBufferFloatExtension = this.gl.getExtension(COLOR_BUFFER_FLOAT); + if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { + this.colorBufferHalfFloatExtension = getExtensionOrThrow(this.gl, this.debug, COLOR_BUFFER_HALF_FLOAT); + } + else if (env().get('WEBGL_FORCE_F16_TEXTURES')) { + throw new Error('GL context does not support color renderable half floats, yet ' + + 'the environment flag WEBGL_FORCE_F16_TEXTURES is set to true.'); + } + } + else { + COLOR_BUFFER_FLOAT = 'EXT_color_buffer_float'; + if (hasExtension(this.gl, COLOR_BUFFER_FLOAT)) { + this.colorBufferFloatExtension = + this.gl.getExtension(COLOR_BUFFER_FLOAT); + } + else if (hasExtension(this.gl, COLOR_BUFFER_HALF_FLOAT)) { + this.colorBufferHalfFloatExtension = + this.gl.getExtension(COLOR_BUFFER_HALF_FLOAT); + } + else { + throw new Error('GL context does not support color renderable floats'); + } + } + this.vertexBuffer = createVertexBuffer(this.gl, this.debug); + this.indexBuffer = createIndexBuffer(this.gl, this.debug); + this.framebuffer = createFramebuffer(this.gl, this.debug); + this.textureConfig = + getTextureConfig(this.gl, this.textureHalfFloatExtension); + } + Object.defineProperty(GPGPUContext.prototype, "debug", { + get: function () { + return env().getBool('DEBUG'); + }, + enumerable: true, + configurable: true + }); + GPGPUContext.prototype.dispose = function () { + var _this = this; + if (this.disposed) { + return; + } + if (this.program != null) { + console.warn('Disposing a GPGPUContext that still has a bound WebGLProgram.' + + ' This is probably a resource leak, delete the program with ' + + 'GPGPUContext.deleteProgram before disposing.'); + } + if (this.outputTexture != null) { + console.warn('Disposing a GPGPUContext that still has a bound output matrix ' + + 'texture. This is probably a resource leak, delete the output ' + + 'matrix texture with GPGPUContext.deleteMatrixTexture before ' + + 'disposing.'); + } + var gl = this.gl; + callAndCheck(gl, this.debug, function () { return gl.finish(); }); + callAndCheck(gl, this.debug, function () { return gl.bindFramebuffer(gl.FRAMEBUFFER, null); }); + callAndCheck(gl, this.debug, function () { return gl.deleteFramebuffer(_this.framebuffer); }); + callAndCheck(gl, this.debug, function () { return gl.bindBuffer(gl.ARRAY_BUFFER, null); }); + callAndCheck(gl, this.debug, function () { return gl.bindBuffer(gl.ELEMENT_ARRAY_BUFFER, null); }); + callAndCheck(gl, this.debug, function () { return gl.deleteBuffer(_this.indexBuffer); }); + this.disposed = true; + }; + GPGPUContext.prototype.createFloat32MatrixTexture = function (rows, columns) { + this.throwIfDisposed(); + return createFloat32MatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); + }; + GPGPUContext.prototype.createFloat16MatrixTexture = function (rows, columns) { + this.throwIfDisposed(); + return createFloat16MatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); + }; + GPGPUContext.prototype.createUnsignedBytesMatrixTexture = function (rows, columns) { + this.throwIfDisposed(); + return createUnsignedBytesMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); + }; + GPGPUContext.prototype.uploadPixelDataToTexture = function (texture, pixels) { + this.throwIfDisposed(); + uploadPixelDataToTexture(this.gl, this.debug, texture, pixels); + }; + GPGPUContext.prototype.uploadDenseMatrixToTexture = function (texture, width, height, data) { + this.throwIfDisposed(); + uploadDenseMatrixToTexture(this.gl, this.debug, texture, width, height, data, this.textureConfig); + }; + GPGPUContext.prototype.createFloat16PackedMatrixTexture = function (rows, columns) { + this.throwIfDisposed(); + return createFloat16PackedMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); + }; + GPGPUContext.prototype.createPackedMatrixTexture = function (rows, columns) { + this.throwIfDisposed(); + return createPackedMatrixTexture(this.gl, this.debug, rows, columns, this.textureConfig); + }; + GPGPUContext.prototype.deleteMatrixTexture = function (texture) { + var _this = this; + this.throwIfDisposed(); + if (this.outputTexture === texture) { + unbindColorTextureFromFramebuffer(this.gl, this.debug, this.framebuffer); + this.outputTexture = null; + } + callAndCheck(this.gl, this.debug, function () { return _this.gl.deleteTexture(texture); }); + }; + GPGPUContext.prototype.downloadByteEncodedFloatMatrixFromOutputTexture = function (texture, rows, columns) { + var _this = this; + return this.downloadMatrixDriver(texture, function () { return downloadByteEncodedFloatMatrixFromOutputTexture(_this.gl, _this.debug, rows, columns, _this.textureConfig); }); + }; + GPGPUContext.prototype.downloadPackedMatrixFromBuffer = function (buffer, batch, rows, columns, physicalRows, physicalCols) { + return downloadPackedMatrixFromBuffer(this.gl, buffer, batch, rows, columns, physicalRows, physicalCols, this.textureConfig); + }; + GPGPUContext.prototype.downloadFloat32MatrixFromBuffer = function (buffer, size) { + return downloadFloat32MatrixFromBuffer(this.gl, buffer, size); + }; + GPGPUContext.prototype.createBufferFromTexture = function (texture, rows, columns) { + this.bindTextureToFrameBuffer(texture); + var result = createBufferFromOutputTexture(this.gl, this.debug, rows, columns, this.textureConfig); + this.unbindTextureToFrameBuffer(); + return result; + }; + GPGPUContext.prototype.createAndWaitForFence = function () { + var fenceContext = this.createFence(this.gl); + return this.pollFence(fenceContext); + }; + GPGPUContext.prototype.createFence = function (gl) { + var _this = this; + var query; + var isFencePassed; + if (env().getBool('WEBGL_FENCE_API_ENABLED')) { + var gl2_1 = gl; + var sync_1 = gl2_1.fenceSync(gl2_1.SYNC_GPU_COMMANDS_COMPLETE, 0); + gl.flush(); + isFencePassed = function () { + var status = gl2_1.clientWaitSync(sync_1, 0, 0); + return status === gl2_1.ALREADY_SIGNALED || + status === gl2_1.CONDITION_SATISFIED; + }; + query = sync_1; + } + else if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + query = this.beginQuery(); + this.endQuery(); + isFencePassed = function () { return _this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); }; + } + else { + // If we have no way to fence, return true immediately. This will fire in + // WebGL 1.0 when there is no disjoint query timer. In this case, because + // the fence passes immediately, we'll immediately ask for a download of + // the texture, which will cause the UI thread to hang. + isFencePassed = function () { return true; }; + } + return { query: query, isFencePassed: isFencePassed }; + }; + GPGPUContext.prototype.downloadMatrixFromPackedTexture = function (texture, physicalRows, physicalCols) { + var _this = this; + return this.downloadMatrixDriver(texture, function () { return downloadMatrixFromPackedOutputTexture(_this.gl, _this.debug, physicalRows, physicalCols); }); + }; + GPGPUContext.prototype.createProgram = function (fragmentShaderSource) { + this.throwIfDisposed(); + var gl = this.gl; + var fragmentShader = createFragmentShader(gl, this.debug, fragmentShaderSource); + var vertexShader = createVertexShader$1(gl, this.debug); + var program = createProgram(gl, this.debug); + callAndCheck(gl, this.debug, function () { return gl.attachShader(program, vertexShader); }); + callAndCheck(gl, this.debug, function () { return gl.attachShader(program, fragmentShader); }); + linkProgram(gl, this.debug, program); + if (this.debug) { + validateProgram(gl, this.debug, program); + } + if (!this.vertexAttrsAreBound) { + this.setProgram(program); + this.vertexAttrsAreBound = bindVertexProgramAttributeStreams(gl, this.debug, this.program, this.vertexBuffer); + } + return program; + }; + GPGPUContext.prototype.deleteProgram = function (program) { + var _this = this; + this.throwIfDisposed(); + if (program === this.program) { + this.program = null; + } + if (program != null) { + callAndCheck(this.gl, this.debug, function () { return _this.gl.deleteProgram(program); }); + } + }; + GPGPUContext.prototype.setProgram = function (program) { + var _this = this; + this.throwIfDisposed(); + this.program = program; + if ((this.program != null) && this.debug) { + validateProgram(this.gl, this.debug, this.program); + } + callAndCheck(this.gl, this.debug, function () { return _this.gl.useProgram(program); }); + }; + GPGPUContext.prototype.getUniformLocation = function (program, uniformName, shouldThrow) { + if (shouldThrow === void 0) { shouldThrow = true; } + this.throwIfDisposed(); + if (shouldThrow) { + return getProgramUniformLocationOrThrow(this.gl, this.debug, program, uniformName); + } + else { + return getProgramUniformLocation(this.gl, program, uniformName); + } + }; + GPGPUContext.prototype.getAttributeLocation = function (program, attribute) { + var _this = this; + this.throwIfDisposed(); + return callAndCheck(this.gl, this.debug, function () { return _this.gl.getAttribLocation(program, attribute); }); + }; + GPGPUContext.prototype.getUniformLocationNoThrow = function (program, uniformName) { + this.throwIfDisposed(); + return this.gl.getUniformLocation(program, uniformName); + }; + GPGPUContext.prototype.setInputMatrixTexture = function (inputMatrixTexture, uniformLocation, textureUnit) { + this.throwIfDisposed(); + this.throwIfNoProgram(); + bindTextureToProgramUniformSampler(this.gl, this.debug, this.program, inputMatrixTexture, uniformLocation, textureUnit); + }; + GPGPUContext.prototype.setOutputMatrixTexture = function (outputMatrixTexture, rows, columns) { + this.setOutputMatrixTextureDriver(outputMatrixTexture, columns, rows); + }; + GPGPUContext.prototype.setOutputPackedMatrixTexture = function (outputPackedMatrixTexture, rows, columns) { + this.throwIfDisposed(); + var _a = getPackedMatrixTextureShapeWidthHeight(rows, columns), width = _a[0], height = _a[1]; + this.setOutputMatrixTextureDriver(outputPackedMatrixTexture, width, height); + }; + GPGPUContext.prototype.setOutputMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) { + this.setOutputMatrixWriteRegionDriver(startColumn, startRow, numColumns, numRows); + }; + GPGPUContext.prototype.setOutputPackedMatrixWriteRegion = function (startRow, numRows, startColumn, numColumns) { + throw new Error('setOutputPackedMatrixWriteRegion not implemented.'); + }; + GPGPUContext.prototype.debugValidate = function () { + if (this.program != null) { + validateProgram(this.gl, this.debug, this.program); + } + validateFramebuffer(this.gl); + }; + GPGPUContext.prototype.executeProgram = function () { + this.throwIfDisposed(); + this.throwIfNoProgram(); + var gl = this.gl; + if (this.debug) { + this.debugValidate(); + } + callAndCheck(gl, this.debug, function () { return gl.drawElements(gl.TRIANGLES, 6, gl.UNSIGNED_SHORT, 0); }); + }; + GPGPUContext.prototype.blockUntilAllProgramsCompleted = function () { + var _this = this; + this.throwIfDisposed(); + callAndCheck(this.gl, this.debug, function () { return _this.gl.finish(); }); + }; + GPGPUContext.prototype.getQueryTimerExtension = function () { + if (this.disjointQueryTimerExtension == null) { + this.disjointQueryTimerExtension = + getExtensionOrThrow(this.gl, this.debug, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2 ? + 'EXT_disjoint_timer_query_webgl2' : + 'EXT_disjoint_timer_query'); + } + return this.disjointQueryTimerExtension; + }; + GPGPUContext.prototype.getQueryTimerExtensionWebGL2 = function () { + return this.getQueryTimerExtension(); + }; + GPGPUContext.prototype.getQueryTimerExtensionWebGL1 = function () { + return this.getQueryTimerExtension(); + }; + GPGPUContext.prototype.beginQuery = function () { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { + var gl2 = this.gl; + var ext_1 = this.getQueryTimerExtensionWebGL2(); + var query_1 = gl2.createQuery(); + gl2.beginQuery(ext_1.TIME_ELAPSED_EXT, query_1); + return query_1; + } + var ext = this.getQueryTimerExtensionWebGL1(); + var query = ext.createQueryEXT(); + ext.beginQueryEXT(ext.TIME_ELAPSED_EXT, query); + return query; + }; + GPGPUContext.prototype.endQuery = function () { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') === 2) { + var gl2 = this.gl; + var ext_2 = this.getQueryTimerExtensionWebGL2(); + gl2.endQuery(ext_2.TIME_ELAPSED_EXT); + return; + } + var ext = this.getQueryTimerExtensionWebGL1(); + ext.endQueryEXT(ext.TIME_ELAPSED_EXT); + }; + GPGPUContext.prototype.waitForQueryAndGetTime = function (query) { + return __awaiter(this, void 0, void 0, function () { + var _this = this; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, repeatedTry(function () { return _this.disposed || // while testing contexts are created / disposed + // in rapid succession, so without this check we + // may poll for the query timer indefinitely + _this.isQueryAvailable(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION')); })]; + case 1: + _a.sent(); + return [2 /*return*/, this.getQueryTime(query, env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION'))]; + } + }); + }); + }; + GPGPUContext.prototype.getQueryTime = function (query, queryTimerVersion) { + if (queryTimerVersion === 0) { + return null; + } + if (queryTimerVersion === 2) { + var gl2 = this.gl; + var timeElapsedNanos = gl2.getQueryParameter(query, gl2.QUERY_RESULT); + // Return milliseconds. + return timeElapsedNanos / 1000000; + } + else { + var ext = this.getQueryTimerExtensionWebGL1(); + var timeElapsedNanos = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_EXT); + // Return milliseconds. + return timeElapsedNanos / 1000000; + } + }; + GPGPUContext.prototype.isQueryAvailable = function (query, queryTimerVersion) { + if (queryTimerVersion === 0) { + return true; + } + if (queryTimerVersion === 2) { + var gl2 = this.gl; + var ext = this.getQueryTimerExtensionWebGL2(); + var available = gl2.getQueryParameter(query, gl2.QUERY_RESULT_AVAILABLE); + if (this.disjoint == null) { + this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); + } + return available && !this.disjoint; + } + else { + var ext = this.getQueryTimerExtensionWebGL1(); + var available = ext.getQueryObjectEXT(query, ext.QUERY_RESULT_AVAILABLE_EXT); + if (this.disjoint == null) { + this.disjoint = this.gl.getParameter(ext.GPU_DISJOINT_EXT); + } + return available && !this.disjoint; + } + }; + GPGPUContext.prototype.pollFence = function (fenceContext) { + var _this = this; + return new Promise(function (resolve) { + _this.addItemToPoll(function () { return fenceContext.isFencePassed(); }, function () { return resolve(); }); + }); + }; + GPGPUContext.prototype.pollItems = function () { + // Find the last query that has finished. + var index = linearSearchLastTrue(this.itemsToPoll.map(function (x) { return x.isDoneFn; })); + for (var i = 0; i <= index; ++i) { + var resolveFn = this.itemsToPoll[i].resolveFn; + resolveFn(); + } + this.itemsToPoll = this.itemsToPoll.slice(index + 1); + }; + GPGPUContext.prototype.addItemToPoll = function (isDoneFn, resolveFn) { + var _this = this; + this.itemsToPoll.push({ isDoneFn: isDoneFn, resolveFn: resolveFn }); + if (this.itemsToPoll.length > 1) { + // We already have a running loop that polls. + return; + } + // Start a new loop that polls. + repeatedTry(function () { + _this.pollItems(); + // End the loop if no more items to poll. + return _this.itemsToPoll.length === 0; + }); + }; + GPGPUContext.prototype.bindTextureToFrameBuffer = function (texture) { + this.throwIfDisposed(); + bindColorTextureToFramebuffer(this.gl, this.debug, texture, this.framebuffer); + if (this.debug) { + validateFramebuffer(this.gl); + } + }; + GPGPUContext.prototype.unbindTextureToFrameBuffer = function () { + if (this.outputTexture != null) { + bindColorTextureToFramebuffer(this.gl, this.debug, this.outputTexture, this.framebuffer); + if (this.debug) { + validateFramebuffer(this.gl); + } + } + else { + unbindColorTextureFromFramebuffer(this.gl, this.debug, this.framebuffer); + } + }; + GPGPUContext.prototype.downloadMatrixDriver = function (texture, downloadAndDecode) { + this.bindTextureToFrameBuffer(texture); + var result = downloadAndDecode(); + this.unbindTextureToFrameBuffer(); + return result; + }; + GPGPUContext.prototype.setOutputMatrixTextureDriver = function (outputMatrixTextureMaybePacked, width, height) { + this.throwIfDisposed(); + var gl = this.gl; + bindColorTextureToFramebuffer(gl, this.debug, outputMatrixTextureMaybePacked, this.framebuffer); + if (this.debug) { + validateFramebuffer(gl); + } + this.outputTexture = outputMatrixTextureMaybePacked; + callAndCheck(gl, this.debug, function () { return gl.viewport(0, 0, width, height); }); + callAndCheck(gl, this.debug, function () { return gl.scissor(0, 0, width, height); }); + }; + GPGPUContext.prototype.setOutputMatrixWriteRegionDriver = function (x, y, width, height) { + var _this = this; + this.throwIfDisposed(); + callAndCheck(this.gl, this.debug, function () { return _this.gl.scissor(x, y, width, height); }); + }; + GPGPUContext.prototype.throwIfDisposed = function () { + if (this.disposed) { + throw new Error('Attempted to use disposed GPGPUContext.'); + } + }; + GPGPUContext.prototype.throwIfNoProgram = function () { + if (this.program == null) { + throw new Error('No GPU program is currently set.'); + } + }; + return GPGPUContext; + }()); + /** + * Finds the index of the last true element using linear search. + * Note: We can't do binary search because Chrome expects us to explicitly + * test all fences before download: + * https://github.com/tensorflow/tfjs/issues/1145 + */ + function linearSearchLastTrue(arr) { + var i = 0; + for (; i < arr.length; ++i) { + var isDone = arr[i](); + if (!isDone) { + break; + } + } + return i - 1; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function compileProgram(gpgpu, program, inputs, output) { + var userCode = program.userCode; + var inputInfos = inputs.map(function (input, i) { + var shapeInfo = { + logicalShape: input.shape, + texShape: input.isUniform ? null : input.texData.texShape, + isUniform: input.isUniform, + isPacked: input.isUniform ? false : input.texData.isPacked, + flatOffset: null + }; + if (input.texData != null && input.texData.slice != null && + input.texData.slice.flatOffset > 0) { + shapeInfo.flatOffset = input.texData.slice.flatOffset; + } + return { name: program.variableNames[i], shapeInfo: shapeInfo }; + }); + var inShapeInfos = inputInfos.map(function (x) { return x.shapeInfo; }); + var outShapeInfo = { + logicalShape: output.shape, + texShape: output.texData.texShape, + isUniform: false, + isPacked: output.texData.isPacked, + flatOffset: null + }; + var source = makeShader(inputInfos, outShapeInfo, userCode, program.packedInputs); + var webGLProgram = gpgpu.createProgram(source); + // Add special uniforms (NAN, INFINITY) + var infLoc = null; + var nanLoc = gpgpu.getUniformLocation(webGLProgram, 'NAN', false); + if (env().getNumber('WEBGL_VERSION') === 1) { + infLoc = gpgpu.getUniformLocation(webGLProgram, 'INFINITY', false); + } + // Add user-defined uniforms + var uniformLocations = {}; + for (var i = 0; i < program.variableNames.length; i++) { + var varName = program.variableNames[i]; + var shouldThrow = false; + uniformLocations[varName] = + gpgpu.getUniformLocation(webGLProgram, varName, shouldThrow); + uniformLocations["offset" + varName] = + gpgpu.getUniformLocation(webGLProgram, "offset" + varName, shouldThrow); + } + return { + program: program, + source: source, + webGLProgram: webGLProgram, + uniformLocations: uniformLocations, + inShapeInfos: inShapeInfos, + outShapeInfo: outShapeInfo, + infLoc: infLoc, + nanLoc: nanLoc, + }; + } + function validateBinaryAndProgram(shapeInfos, inputs) { + if (shapeInfos.length !== inputs.length) { + throw Error("Binary was compiled with " + shapeInfos.length + " inputs, but " + + ("was executed with " + inputs.length + " inputs")); + } + shapeInfos.forEach(function (s, i) { + var shapeA = s.logicalShape; + var input = inputs[i]; + var shapeB = input.shape; + if (!arraysEqual(shapeA, shapeB)) { + throw Error("Binary was compiled with different shapes than " + + ("the current args. Shapes " + shapeA + " and " + shapeB + " must match")); + } + // The input is uploaded as uniform. + if (s.isUniform && input.isUniform) { + return; + } + var texShapeA = s.texShape; + var texShapeB = input.isUniform ? null : input.texData.texShape; + if (!arraysEqual(texShapeA, texShapeB)) { + throw Error("Binary was compiled with different texture shapes than the" + + (" current args. Shape " + texShapeA + " and " + texShapeB + " must match")); + } + }); + } + function runProgram(gpgpu, binary, inputs, output, customSetup) { + validateBinaryAndProgram(binary.inShapeInfos, inputs); + validateBinaryAndProgram([binary.outShapeInfo], [output]); + var outTex = output.texData.texture; + var outTexShape = output.texData.texShape; + if (output.texData.isPacked) { + gpgpu.setOutputPackedMatrixTexture(outTex, outTexShape[0], outTexShape[1]); + } + else { + gpgpu.setOutputMatrixTexture(outTex, outTexShape[0], outTexShape[1]); + } + gpgpu.setProgram(binary.webGLProgram); + // Set special uniforms (NAN, INFINITY) + if (env().getNumber('WEBGL_VERSION') === 1) { + if (binary.infLoc !== null) { + gpgpu.gl.uniform1f(binary.infLoc, Infinity); + } + } + if (binary.nanLoc !== null) { + gpgpu.gl.uniform1f(binary.nanLoc, NaN); + } + // Set user-defined inputs + inputs.forEach(function (input, i) { + var varName = binary.program.variableNames[i]; + var varLoc = binary.uniformLocations[varName]; + var varOffsetLoc = binary.uniformLocations["offset" + varName]; + if (varLoc == null) { + // The compiler inferred that this variable is not used in this shader. + return; + } + if (input.isUniform) { + // Upload the values of the tensor as uniform. + if (sizeFromShape(input.shape) < 2) { + gpgpu.gl.uniform1f(varLoc, input.uniformValues[0]); + } + else { + var vals = input.uniformValues; + if (!(vals instanceof Float32Array)) { + vals = new Float32Array(vals); + } + gpgpu.gl.uniform1fv(varLoc, vals); + } + return; + } + // If the input was sliced, upload the flat offset index. + if (input.texData.slice != null && varOffsetLoc != null) { + gpgpu.gl.uniform1i(varOffsetLoc, input.texData.slice.flatOffset); + } + gpgpu.setInputMatrixTexture(input.texData.texture, varLoc, i); + }); + if (customSetup != null) { + customSetup(gpgpu, binary.webGLProgram); + } + gpgpu.executeProgram(); + } + function makeShaderKey(program, inputs, output) { + var keyInputs = ''; + inputs.concat(output).forEach(function (x) { + var hasOffset = x.texData != null && x.texData.slice != null && + x.texData.slice.flatOffset > 0; + var texShape = x.isUniform ? 'uniform' : x.texData.texShape; + keyInputs += x.shape + "_" + texShape + "_" + hasOffset; + }); + var keyUserCode = program.userCode; + var key = program.constructor.name; + // Fast string concat. See https://jsperf.com/string-concatenation/14. + key += '_' + keyInputs + '_' + keyUserCode; + return key; + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var Im2ColPackedProgram = /** @class */ (function () { + function Im2ColPackedProgram(outputShape, inputShape, convInfo) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = outputShape; + var filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, strideWidth = convInfo.strideWidth, strideHeight = convInfo.strideHeight, padInfo = convInfo.padInfo, outWidth = convInfo.outWidth, dilationWidth = convInfo.dilationWidth, dilationHeight = convInfo.dilationHeight, dataFormat = convInfo.dataFormat; + var left = padInfo.left, top = padInfo.top; + var itemsPerBlockRow = inChannels * filterWidth; + var glsl = getGlslDifferences(); + var isChannelsLast = dataFormat === 'channelsLast'; + var rowDim = isChannelsLast ? 0 : 1; + var colDim = isChannelsLast ? 1 : 2; + var unrolled = ""; + for (var row = 0; row <= 1; row++) { + for (var col = 0; col <= 1; col++) { + unrolled += "\n blockIndex = rc.y + " + col + ";\n pos = rc.x + " + row + ";\n\n if(blockIndex < " + outputShape[1] + " && pos < " + outputShape[0] + ") {\n offsetY = int(blockIndex / (" + outWidth + ")) * " + strideHeight + " - " + top + ";\n d0 = offsetY + " + dilationHeight + " * (pos / " + itemsPerBlockRow + ");\n\n if(d0 < " + inputShape[rowDim] + " && d0 >= 0) {\n\n offsetX = int(mod(float(blockIndex), " + outWidth + ".) * " + strideWidth + ". - " + left + ".);\n d1 = offsetX + " + dilationWidth + " * (int(mod(float(pos), " + itemsPerBlockRow + ".) / " + inChannels + ".));\n\n if(d1 < " + inputShape[colDim] + " && d1 >= 0) {\n\n ch = int(mod(float(pos), " + inChannels + ".));\n\n if (" + isChannelsLast + ") {\n innerDims = vec2(d1, ch);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(d0, int(innerDims.x),\n int(innerDims.y)), innerDims);\n } else {\n innerDims = vec2(d0, d1);\n result[" + (row * 2 + col) + "] = getChannel(\n getA(ch, int(innerDims.x),\n int(innerDims.y)), innerDims);\n }\n }\n }\n }\n "; + } + } + this.userCode = "\n void main() {\n ivec2 rc = getOutputCoords();\n\n vec4 result = vec4(0);\n\n int blockIndex, pos, offsetY, d0, offsetX, d1, ch;\n vec2 innerDims;\n\n " + unrolled + "\n\n " + glsl.output + " = result;\n }\n "; + } + return Im2ColPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var LRNProgram = /** @class */ (function () { + function LRNProgram(xShape, radius, bias, alpha, beta) { + this.variableNames = ['x']; + this.outputShape = []; + var rad = radius; + var maxD = xShape[3] - 1; + this.outputShape = xShape; + // optimize pow(bias + alpha * sum, -beta) + // src: https://github.com/tensorflow/tensorflow/.. + // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. + // tensorflow/core/kernels/mkl_lrn_op.cc#L320 + var powOperator; + var basis = "float(" + bias + ") + float(" + alpha + ") * sum"; + if (beta === 0.5) { + powOperator = "inversesqrt(" + basis + ")"; + } + else if (beta === 1.0) { + powOperator = "1.0/(" + basis + ")"; + } + else { + powOperator = "exp(log(" + basis + ") * float(-" + beta + "));"; + } + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n int d = coords[3];\n float x = getX(b, r, c, d);\n float sum = 0.0;\n for (int j = -" + rad + "; j <= " + rad + "; j++) {\n int idx = d + j;\n if (idx >= 0 && idx <= " + maxD + ") {\n float z = getX(b, r, c, idx);\n sum += z * z;\n }\n }\n float val = x * " + powOperator + ";\n setOutput(val);\n }\n "; + } + return LRNProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var LRNGradProgram = /** @class */ (function () { + function LRNGradProgram(inputShape, depthRadius, bias, alpha, beta) { + this.variableNames = ['inputImage', 'outputImage', 'dy']; + this.outputShape = []; + this.outputShape = inputShape; + this.depth = inputShape[3]; + this.depthRadius = depthRadius; + this.bias = bias; + this.alpha = alpha; + this.beta = beta; + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int r = coords[1];\n int c = coords[2];\n\n float result = 0.0;\n for (int d = 0; d < " + this.depth + "; ++d) {\n int depthBegin = int(max(0.0, float(d - " + depthRadius + ")));\n int depthEnd = int(min(float(" + this.depth + "),\n float(d + " + depthRadius + " + 1)));\n\n const int MIN_DEPTH_BEGIN = 0;\n const int MAX_DEPTH_END = " + this.depth + ";\n\n float norm = 0.0;\n for (int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k) {\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd) {\n norm += getInputImage(b, r, c, k) * getInputImage(b, r, c, k);\n }\n else {\n break;\n }\n }\n\n norm = float(" + alpha + ") * norm + float(" + bias + ");\n\n for(int k = MIN_DEPTH_BEGIN; k < MAX_DEPTH_END; ++k){\n if (k < depthBegin){\n continue;\n }\n else if (k >= depthBegin && k < depthEnd){\n float dyi = -2.0 * float(" + alpha + ")\n * float(" + beta + ")\n * getInputImage(b ,r ,c, k) * getOutputImage(b, r, c, d)\n / norm;\n if (k == d) {\n dyi += pow(norm, -1.0 * " + beta + ");\n }\n if (k == coords[3]) {\n dyi *= getDy(b, r, c, d);\n result += dyi;\n }\n }\n else {\n break;\n }\n }\n }\n setOutput(result);\n }\n "; + } + return LRNGradProgram; + }()); + + /** + * @license + * Copyright 2019 Google LLC All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var LRNPackedProgram = /** @class */ (function () { + function LRNPackedProgram(xShape, radius, bias, alpha, beta) { + this.variableNames = ['x']; + this.outputShape = []; + this.packedInputs = true; + this.packedOutput = true; + var rad = radius; + var maxD = xShape[3] - 1; + this.outputShape = xShape; + // optimize pow(bias + alpha * sum, -beta) + // src: https://github.com/tensorflow/tensorflow/.. + // blob/26033a1644a9c4a5fbe3170ab2e864b6a4ccd4ca/.. + // tensorflow/core/kernels/mkl_lrn_op.cc#L320 + var powOperator; + var basis = "float(" + bias + ") + float(" + alpha + ") * sum"; + if (beta === 0.5) { + powOperator = "inversesqrt(" + basis + ")"; + } + else if (beta === 1.0) { + powOperator = "1.0/(" + basis + ")"; + } + else { + powOperator = "exp(log(" + basis + ") * float(-" + beta + "));"; + } + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords.x;\n int r = coords.y;\n int c = coords.z;\n int d = coords.w;\n\n bool hasNextCol = d < " + this.outputShape[3] + ";\n bool hasNextRow = c < " + this.outputShape[2] + ";\n\n vec4 sum = vec4(0.);\n vec4 xFragAtOutputCoords = getX(b, r, c, d);\n\n vec4 xAtOutputCoords = vec4(\n getChannel(xFragAtOutputCoords, vec2(c, d)),\n hasNextCol ?\n getChannel(xFragAtOutputCoords, vec2(c, d + 1)) : 0.0,\n hasNextRow ?\n getChannel(xFragAtOutputCoords , vec2(c + 1, d)) : 0.0,\n (hasNextRow && hasNextCol) ?\n getChannel(xFragAtOutputCoords, vec2(c + 1, d + 1)) : 0.0\n );\n\n int firstChannel = d - " + rad + ";\n vec2 cache = vec2(0.);\n if(firstChannel >= 0){\n vec4 firstChannelFrag = getX(b, r, c, firstChannel);\n cache.x = getChannel(firstChannelFrag, vec2(c, firstChannel));\n if(hasNextRow){\n cache.y = getChannel(firstChannelFrag, vec2(c + 1, firstChannel));\n }\n }\n\n ivec2 depth = ivec2(d, d + 1);\n for (int j = - " + rad + "; j <= " + rad + "; j++) {\n ivec2 idx = depth + j;\n bvec2 aboveLowerBound = greaterThanEqual(idx, ivec2(0));\n bvec2 belowUpperBound = lessThanEqual(idx, ivec2(" + maxD + "));\n\n bool depthInRange = aboveLowerBound.x && belowUpperBound.x;\n bool depthPlusOneInRange = aboveLowerBound.y && belowUpperBound.y;\n\n if(depthInRange || depthPlusOneInRange){\n vec4 z = vec4(0.);\n vec4 xFragAtCurrentDepth;\n z.xz = cache.xy;\n if(depthPlusOneInRange && hasNextCol){\n xFragAtCurrentDepth = idx.y != d ?\n getX(b, r, c, idx.y) : xFragAtOutputCoords;\n z.y = getChannel(xFragAtCurrentDepth, vec2(c, idx.y));\n if(hasNextRow){\n z.w = getChannel(xFragAtCurrentDepth, vec2(c + 1, idx.y));\n }\n }\n cache.xy = z.yw;\n sum += z * z;\n }\n }\n vec4 result = xAtOutputCoords * " + powOperator + ";\n setOutput(result);\n }\n "; + } + return LRNPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var MaxPool2DBackpropProgram = /** @class */ (function () { + function MaxPool2DBackpropProgram(convInfo) { + this.variableNames = ['dy', 'maxPos']; + this.outputShape = convInfo.inShape; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var lastIndex = effectiveFilterHeight * effectiveFilterWidth - 1; + this.userCode = "\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n\n ivec2 dyRCCorner = coords.yz - pads;\n int dyRCorner = dyRCCorner.x;\n int dyCCorner = dyRCCorner.y;\n\n // Convolve dy(?, ?, d) with pos mask(:, :, d) to get dx(xR, xC, d).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 || fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + "; wC++) {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(b, idyR, idyC, d);\n int maxPosValue = " + lastIndex + " - int(getMaxPos(b, idyR, idyC, d));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue = wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n setOutput(dotProd);\n }\n "; + } + return MaxPool2DBackpropProgram; + }()); + var MaxPool3DBackpropProgram = /** @class */ (function () { + function MaxPool3DBackpropProgram(convInfo) { + this.variableNames = ['dy', 'maxPos']; + this.outputShape = convInfo.inShape; + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var lastIndex = effectiveFilterDepth * effectiveFilterHeight * effectiveFilterWidth - 1; + this.userCode = "\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 dyCorner = ivec3(coords.y, coords.z, coords.w) - pads;\n int dyDCorner = dyCorner.x;\n int dyRCorner = dyCorner.y;\n int dyCCorner = dyCorner.z;\n\n // Convolve dy(?, ?, ?, ch) with pos mask(:, :, :, d) to get\n // dx(xD, xR, xC, ch).\n // ? = to be determined. : = across all values in that axis.\n float dotProd = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n float dyD = float(dyDCorner + wD) / " + strideDepth + ".0;\n\n if (dyD < 0.0 || dyD >= " + convInfo.outDepth + ".0 || fract(dyD) > 0.0) {\n continue;\n }\n int idyD = int(dyD);\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n float dyR = float(dyRCorner + wR) / " + strideHeight + ".0;\n\n if (dyR < 0.0 || dyR >= " + convInfo.outHeight + ".0 ||\n fract(dyR) > 0.0) {\n continue;\n }\n int idyR = int(dyR);\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n float dyC = float(dyCCorner + wC) / " + strideWidth + ".0;\n\n if (dyC < 0.0 || dyC >= " + convInfo.outWidth + ".0 ||\n fract(dyC) > 0.0) {\n continue;\n }\n int idyC = int(dyC);\n\n float dyValue = getDy(batch, idyD, idyR, idyC, ch);\n int maxPosValue = " + lastIndex + " -\n int(getMaxPos(batch, idyD, idyR, idyC, ch));\n\n // Get the current value, check it against the value from the\n // position matrix.\n int curPosValue =\n wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC;\n float mask = float(maxPosValue == curPosValue ? 1.0 : 0.0);\n\n dotProd += dyValue * mask;\n }\n }\n }\n setOutput(dotProd);\n }\n "; + } + return MaxPool3DBackpropProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var MatMulPackedProgram = /** @class */ (function () { + function MatMulPackedProgram(aShape, outputShape, transposeA, transposeB, addBias, activation, hasPreluActivation) { + if (transposeA === void 0) { transposeA = false; } + if (transposeB === void 0) { transposeB = false; } + if (addBias === void 0) { addBias = false; } + if (activation === void 0) { activation = null; } + if (hasPreluActivation === void 0) { hasPreluActivation = false; } + this.variableNames = ['matrixA', 'matrixB']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = outputShape; + var sharedDim = transposeA ? aShape[1] : aShape[2]; + var sharedDimensionPacked = Math.ceil(sharedDim / 2); + var aSample = transposeA ? 'i * 2, rc.y' : 'rc.y, i * 2'; + var bSample = transposeB ? 'rc.z, i * 2' : 'i * 2, rc.z'; + var aSwizzle = transposeA ? ['a.xxyy', 'a.zzww'] : ['a.xxzz', 'a.yyww']; + var bSwizzle = transposeB ? ['b.xzxz', 'b.ywyw'] : ['b.xyxy', 'b.zwzw']; + var activationSnippet = '', applyActivationSnippet = ''; + if (activation) { + if (hasPreluActivation) { + activationSnippet = "vec4 activation(vec4 a) {\n vec4 b = getPreluActivationWeightsAtOutCoords();\n " + activation + "\n }"; + } + else { + activationSnippet = "vec4 activation(vec4 x) {\n " + activation + "\n }"; + } + applyActivationSnippet = "result = activation(result);"; + } + var addBiasSnippet = addBias ? 'result += getBiasAtOutCoords();' : ''; + if (addBias) { + this.variableNames.push('bias'); + } + if (hasPreluActivation) { + this.variableNames.push('preluActivationWeights'); + } + this.userCode = "\n " + activationSnippet + "\n\n const float sharedDimension = " + sharedDimensionPacked + ".0;\n\n vec4 dot2x2ARowBCol(ivec3 rc) {\n vec4 result = vec4(0);\n for (int i = 0; i < " + sharedDimensionPacked + "; i++) {\n vec4 a = getMatrixA(rc.x, " + aSample + ");\n vec4 b = getMatrixB(rc.x, " + bSample + ");\n\n // These swizzled products need to be separately added.\n // See: https://github.com/tensorflow/tfjs/issues/1735\n result += (" + aSwizzle[0] + " * " + bSwizzle[0] + ");\n result += (" + aSwizzle[1] + " * " + bSwizzle[1] + ");\n }\n return result;\n }\n\n void main() {\n ivec3 rc = getOutputCoords();\n vec4 result = dot2x2ARowBCol(rc);\n\n " + addBiasSnippet + "\n\n " + applyActivationSnippet + "\n\n setOutput(result);\n }\n "; + } + return MatMulPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var MultinomialProgram = /** @class */ (function () { + function MultinomialProgram(batchSize, numOutcomes, numSamples) { + this.variableNames = ['probs']; + this.outputShape = [batchSize, numSamples]; + this.userCode = "\n uniform float seed;\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n\n float r = random(seed);\n float cdf = 0.0;\n\n for (int i = 0; i < " + (numOutcomes - 1) + "; i++) {\n cdf += getProbs(batch, i);\n\n if (r < cdf) {\n setOutput(float(i));\n return;\n }\n }\n\n // If no other event happened, last event happened.\n setOutput(float(" + (numOutcomes - 1) + "));\n }\n "; + } + MultinomialProgram.prototype.getCustomSetupFunc = function (seed) { + var _this = this; + return function (gpgpu, webGLProgram) { + if (_this.seedLoc == null) { + _this.seedLoc = gpgpu.getUniformLocation(webGLProgram, 'seed'); + } + gpgpu.gl.uniform1f(_this.seedLoc, seed); + }; + }; + return MultinomialProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var OneHotProgram = /** @class */ (function () { + function OneHotProgram(numIndices, depth, onValue, offValue) { + this.variableNames = ['indices']; + this.outputShape = [numIndices, depth]; + this.userCode = "\n void main() {\n ivec2 coords = getOutputCoords();\n int index = round(getIndices(coords.x));\n setOutput(mix(float(" + offValue + "), float(" + onValue + "),\n float(index == coords.y)));\n }\n "; + } + return OneHotProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PackProgram = /** @class */ (function () { + function PackProgram(outputShape) { + this.variableNames = ['A']; + this.packedInputs = false; + this.packedOutput = true; + // Only input / output 3D tensors. + this.outputShape = outputShape; + var rank = outputShape.length; + if (rank === 0) { + this.userCode = "\n void main() {\n setOutput(vec4(getA(), 0., 0., 0.));\n }\n "; + } + else { + var channels = getChannels('rc', rank); + var dtype = getCoordsDataType(rank); + var outOfBoundsCondition = getOutOfBoundsCondition(rank, outputShape, channels); + var setup = getSetup(rank, outputShape[outputShape.length - 1], outputShape[outputShape.length - 2], channels); + var output = getOutput(outputShape, channels); + this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n\n if(" + outOfBoundsCondition + ") {\n setOutput(vec4(0));\n } else {\n " + setup + "\n\n setOutput(vec4(" + output + "));\n }\n }\n "; + } + } + return PackProgram; + }()); + function getSourceCoordsArr(rank, dims) { + var coords = []; + for (var row = 0; row <= 1; row++) { + for (var col = 0; col <= 1; col++) { + var coord = (row === 0 ? 'r' : 'rp1') + ", " + (col === 0 ? 'c' : 'cp1'); + for (var d = 2; d < rank; d++) { + coord = dims[dims.length - 1 - d] + "," + coord; + } + coords.push(coord); + } + } + return coords; + } + function getOutOfBoundsCondition(rank, shape, dims) { + if (rank === 1) { + return "rc > " + shape[0]; + } + var cond = ''; + for (var i = rank - 2; i < rank; i++) { + cond += dims[i] + " >= " + shape[i]; + if (i < rank - 1) { + cond += '||'; + } + } + return cond; + } + function getSetup(rank, cols, rows, dims) { + if (rank === 1) { + return ''; + } + var innerDims = dims.slice(-2); + return "\n int r = " + innerDims[0] + ";\n int c = " + innerDims[1] + ";\n int rp1 = r + 1;\n int cp1 = c + 1;\n\n bool cEdge = cp1 >= " + cols + ";\n bool rEdge = rp1 >= " + rows + ";\n "; + } + function getOutput(shape, dims) { + var rank = shape.length; + var sourceCoords = getSourceCoordsArr(rank, dims); + if (rank === 1) { + return "getA(rc),\n rc + 1 >= " + shape[0] + " ? 0. : getA(rc + 1),\n 0, 0"; + } + return "getA(" + sourceCoords[0] + "),\n cEdge ? 0. : getA(" + sourceCoords[1] + "),\n rEdge ? 0. : getA(" + sourceCoords[2] + "),\n rEdge || cEdge ? 0. : getA(" + sourceCoords[3] + ")"; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PadProgram = /** @class */ (function () { + function PadProgram(xShape, paddings, constantValue) { + this.variableNames = ['x']; + this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */); + var rank = xShape.length; + var type = getCoordsDataType(rank); + var start = paddings.map(function (p) { return p[0]; }).join(','); + var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(','); + var unpackedCoords = ['coords[0]', 'coords[1]', 'coords[2]', 'coords[3]'].slice(0, rank); + if (rank === 1) { + this.userCode = "\n int start = " + start + ";\n int end = " + end + ";\n\n void main() {\n int outC = getOutputCoords();\n if (outC < start || outC >= end) {\n setOutput(float(" + constantValue + "));\n } else {\n setOutput(getX(outC - start));\n }\n }\n "; + return; + } + this.userCode = "\n " + type + " start = " + type + "(" + start + ");\n " + type + " end = " + type + "(" + end + ");\n\n void main() {\n " + type + " outC = getOutputCoords();\n if (any(lessThan(outC, start)) || any(greaterThanEqual(outC, end))) {\n setOutput(float(" + constantValue + "));\n } else {\n " + type + " coords = outC - start;\n setOutput(getX(" + unpackedCoords + "));\n }\n }\n "; + } + return PadProgram; + }()); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PadPackedProgram = /** @class */ (function () { + function PadPackedProgram(xShape, paddings, constantValue) { + this.variableNames = ['x']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + xShape[i] + p[1]; } /* afterPad */); + var rank = xShape.length; + var dtype = getCoordsDataType(rank); + var start = paddings.map(function (p) { return p[0]; }).join(','); + var end = paddings.map(function (p, i) { return p[0] + xShape[i]; }).join(','); + var coords = getChannels('rc', rank); + var source = getChannels('source', rank); + var cLimit = coords[rank - 1] + " < " + this.outputShape[rank - 1]; + var innerDims = rank === 1 ? 'source' : "vec2(" + source.slice(-2).join() + ")"; + var componentSetup = [ + dtype + " rc = outputLoc;", coords[rank - 1] + " += 1;\n if(" + cLimit + ") {\n ", + rank === 1 ? '' : "}\n rc = outputLoc;\n " + coords[rank - 2] + " += 1;\n if(" + coords[rank - 2] + " < " + this.outputShape[rank - 2] + ") {", + rank === 1 ? '' : " " + coords[rank - 1] + " += 1;\n if(" + cLimit + ") {" + ]; + var paddingArea = rank === 1 ? + 'rc < start || rc >= end' : + 'any(lessThan(rc, start)) || any(greaterThanEqual(rc, end))'; + var mainLoop = ''; + for (var i = 0, j = rank === 1 ? 2 : 4; i < j; i++) { + mainLoop += "\n " + componentSetup[i] + "\n if (" + paddingArea + ") {\n result[" + i + "] = float(" + constantValue + ");\n } else {\n " + dtype + " source = rc - start;\n result[" + i + "] = getChannel(getX(" + source.join() + "), " + innerDims + ");\n }\n "; + } + mainLoop += (rank === 1 ? "} " : "}}"); + this.userCode = "\n const " + dtype + " start = " + dtype + "(" + start + ");\n const " + dtype + " end = " + dtype + "(" + end + ");\n\n void main() {\n " + dtype + " outputLoc = getOutputCoords();\n vec4 result = vec4(0.);\n " + mainLoop + "\n setOutput(result);\n }\n "; + } + return PadPackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var Pool2DProgram = /** @class */ (function () { + function Pool2DProgram(convInfo, poolType, computePositions) { + this.variableNames = ['x']; + if (poolType === 'avg' && computePositions) { + throw new Error('Cannot compute positions for average pool.'); + } + var filterWidth = convInfo.filterWidth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + this.outputShape = convInfo.outShape; + var isAvgPool = poolType === 'avg'; + var initializationValue = '0.0'; + if (!isAvgPool) { + // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. + initializationValue = '-1.0 / 1e-20'; + } + if (computePositions) { + var compareOp_1 = '>='; + this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n float avgValue = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xR, xC, d);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + compareOp_1 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition = wR * " + effectiveFilterWidth + " + wC;\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n "; + return; + } + var compareOp = 'max'; + var returnValue = poolType + "(" + poolType + "(" + poolType + "(" + + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; + if (poolType === 'avg') { + returnValue = "avgValue / count"; + } + var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; + var filterWidthVec4Remainder = filterWidth % 4; + var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n "; + this.userCode = "\n const ivec2 strides = ivec2(" + strideHeight + ", " + strideWidth + ");\n const ivec2 pads = ivec2(" + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xR, int xC, int d) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xR, xC, d);\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int batch = coords[0];\n int d = coords[3];\n\n ivec2 xRCCorner = coords.yz * strides - pads;\n int xRCorner = xRCCorner.x;\n int xCCorner = xRCCorner.y;\n\n // max/min x(?, ?, d) to get y(yR, yC, d).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n getValue(batch, xR, xC + 3 * " + dilationWidth + ", d)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xR, xC, d),\n getValue(batch, xR, xC + " + dilationWidth + ", d),\n getValue(batch, xR, xC + 2 * " + dilationWidth + ", d),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n "; + } + return Pool2DProgram; + }()); + var Pool3DProgram = /** @class */ (function () { + function Pool3DProgram(convInfo, poolType, computePositions) { + this.variableNames = ['x']; + if (poolType === 'avg' && computePositions) { + throw new Error('Cannot compute positions for average pool.'); + } + var filterWidth = convInfo.filterWidth; + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = convInfo.padInfo.front; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + this.outputShape = convInfo.outShape; + var isAvgPool = poolType === 'avg'; + var initializationValue = '0.0'; + if (!isAvgPool) { + // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. + initializationValue = '-1.0 / 1e-20'; + } + if (computePositions) { + var compareOp_2 = '>='; + this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, ch) to get y(yD, yR, yC, ch).\n // ? = to be determined\n float minMaxValue = 0.0;\n float minMaxValueFound = 0.0;\n int minMaxPosition = 0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + effectiveFilterWidth + ";\n wC += " + dilationWidth + ") {\n int xC = xCCorner + wC;\n\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n continue;\n }\n\n float value = getX(batch, xD, xR, xC, ch);\n\n // If a min / max value has already been found, use it. If not,\n // use the current value.\n float currMinMaxValue = mix(\n value, minMaxValue, minMaxValueFound);\n if (value " + compareOp_2 + " currMinMaxValue) {\n minMaxValue = value;\n minMaxValueFound = 1.0;\n minMaxPosition =\n wD * " + effectiveFilterHeight + " * " + effectiveFilterWidth + " +\n wR * " + effectiveFilterWidth + " + wC;;\n }\n }\n }\n }\n setOutput(float(minMaxPosition));\n }\n "; + return; + } + var compareOp = 'max'; + var returnValue = poolType + "(" + poolType + "(" + poolType + "(" + + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; + if (poolType === 'avg') { + returnValue = "avgValue / count"; + } + var filterWidthNearestVec4 = Math.floor(filterWidth / 4) * 4; + var filterWidthVec4Remainder = filterWidth % 4; + var updateSnippet = "\n if (" + isAvgPool + ") {\n avgValue += dot(values, ones);\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n "; + this.userCode = "\n const ivec3 strides =\n ivec3(" + strideDepth + ", " + strideHeight + ", " + strideWidth + ");\n const ivec3 pads = ivec3(" + padFront + ", " + padTop + ", " + padLeft + ");\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float count = 0.0;\n\n float getValue(int batch, int xD, int xR, int xC, int ch) {\n if (xC < 0 || xC >= " + convInfo.inWidth + ") {\n return initializationValue;\n }\n count += 1.0;\n return getX(batch, xD, xR, xC, ch);\n }\n\n void main() {\n ivec5 coords = getOutputCoords();\n int batch = coords.x;\n int ch = coords.u;\n\n ivec3 xCorner = ivec3(coords.y, coords.z, coords.w) * strides - pads;\n int xDCorner = xCorner.x;\n int xRCorner = xCorner.y;\n int xCCorner = xCorner.z;\n\n // max/min x(?, ?, ?, d) to get y(yD, yR, yC, ch).\n // ? = to be determined\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float avgValue = 0.0;\n count = 0.0;\n\n for (int wD = 0; wD < " + effectiveFilterDepth + ";\n wD += " + dilationDepth + ") {\n int xD = xDCorner + wD;\n\n if (xD < 0 || xD >= " + convInfo.inDepth + ") {\n continue;\n }\n\n for (int wR = 0; wR < " + effectiveFilterHeight + ";\n wR += " + dilationHeight + ") {\n int xR = xRCorner + wR;\n\n if (xR < 0 || xR >= " + convInfo.inHeight + ") {\n continue;\n }\n\n for (int wC = 0; wC < " + filterWidthNearestVec4 + "; wC += 4) {\n int xC = xCCorner + wC * " + dilationWidth + ";\n\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 3 * " + dilationWidth + ", ch)\n );\n\n " + updateSnippet + "\n }\n\n int xC = xCCorner + " + filterWidthNearestVec4 + ";\n if (" + (filterWidthVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (filterWidthVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, xD, xR, xC, ch),\n getValue(batch, xD, xR, xC + " + dilationWidth + ", ch),\n getValue(batch, xD, xR, xC + 2 * " + dilationWidth + ", ch),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n }\n setOutput(" + returnValue + ");\n }\n }\n "; + } + return Pool3DProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ReduceProgram = /** @class */ (function () { + function ReduceProgram(reduceInfo, reduceType) { + this.variableNames = ['x']; + var windowSize = reduceInfo.windowSize; + var batchSize = reduceInfo.batchSize; + var inSize = reduceInfo.inSize; + var outSize = Math.ceil(inSize / windowSize); + this.outputShape = [batchSize, outSize]; + var initializationValue = '0.0'; + var compareOp = ""; + if (reduceType === 'prod') { + initializationValue = '1.0'; + } + else if (reduceType === 'min') { + // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. + initializationValue = '1.0 / 1e-20'; + compareOp = "min"; + } + else if (reduceType === 'max') { + // WebGL on Firefox Linux can't compile 1/0 so we do 1/eps. + initializationValue = '-1.0 / 1e-20'; + compareOp = "max"; + } + var returnValue = reduceType + "(" + reduceType + "(" + reduceType + "(" + + 'minMaxValue[0], minMaxValue[1]), minMaxValue[2]), minMaxValue[3])'; + if (reduceType === 'sum') { + returnValue = "sumValue"; + } + else if (reduceType === 'prod') { + returnValue = "prodValue"; + } + else if (reduceType === 'all') { + returnValue = "allValue"; + } + else if (reduceType === 'any') { + returnValue = "anyValue"; + } + var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4; + var windowSizeVec4Remainder = windowSize % 4; + var updateSnippet = "\n if (" + (reduceType === 'sum') + ") {\n sumValue += dot(values, ones);\n } else if (" + (reduceType === 'prod') + ") {\n vec2 tmp = vec2(values[0], values[1]) * vec2(values[2], values[3]);\n prodValue *= tmp[0] * tmp[1];\n } else {\n minMaxValue = " + compareOp + "(values, minMaxValue);\n }\n "; + var vecType = "vec4"; + if (reduceType === 'all') { + initializationValue = '1.0'; + updateSnippet = "\n bool reducedAllValue = all(values);\n float floatedReducedAllValue = float(reducedAllValue);\n allValue = float(allValue >= 1.0 && floatedReducedAllValue >= 1.0);\n "; + vecType = "bvec4"; + } + else if (reduceType === 'any') { + initializationValue = '0.0'; + updateSnippet = "\n bool reducedAnyValue = any(values);\n float floatedReducedAnyValue = float(reducedAnyValue);\n anyValue = float(anyValue >= 1.0 || floatedReducedAnyValue >= 1.0);\n "; + vecType = "bvec4"; + } + var checkOutOfBounds = ''; + if (inSize % windowSize > 0) { + checkOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n "; + } + this.userCode = "\n const float initializationValue = " + initializationValue + ";\n const vec4 ones = vec4(1.0, 1.0, 1.0, 1.0);\n\n float getValue(int batch, int inIdx) {\n " + checkOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = outIdx * " + windowSize + ";\n\n vec4 minMaxValue = vec4(" + initializationValue + ");\n float prodValue = 1.0;\n float sumValue = 0.0;\n float allValue = 1.0;\n float anyValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n " + vecType + " values = " + vecType + "(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n "; + } + return ReduceProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ReshapePackedProgram = /** @class */ (function () { + function ReshapePackedProgram(outputShape, inputShape) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = outputShape; + var mainLoop = ""; + for (var i = 0; i < 4; i++) { + var thisRC = "thisRC = rc;"; + if (i % 2 === 1) { + thisRC += "thisRC.z += 1;"; + } + if (i > 1) { + thisRC += "thisRC.y += 1;"; + } + mainLoop += "\n " + thisRC + "\n " + (i > 0 ? "if(thisRC.y < rows && thisRC.z < cols){" : '') + "\n int flatIndex = getFlatIndex(thisRC);\n\n ivec3 inputRC = inputCoordsFromReshapedOutCoords(flatIndex);\n vec2 inputRCInnerDims = vec2(float(inputRC.y),float(inputRC.z));\n\n result[" + i + "] =\n getChannel(getA(inputRC.x, inputRC.y, inputRC.z), inputRCInnerDims);\n " + (i > 0 ? '}' : '') + "\n "; + } + this.userCode = "\n " + getReshapedInputCoords(inputShape) + "\n " + getFlatIndexFrom3D(outputShape) + "\n\n void main() {\n ivec3 rc = getOutputCoords();\n\n vec4 result = vec4(0.);\n\n ivec3 thisRC;\n int rows = " + outputShape[1] + ";\n int cols = " + outputShape[2] + ";\n\n " + mainLoop + "\n\n setOutput(result);\n }\n "; + } + return ReshapePackedProgram; + }()); + function getReshapedInputCoords(shape) { + var coordsFromIndexSnippet = getLogicalCoordinatesFromFlatIndex(['r', 'c', 'd'], shape); + return "\n ivec3 inputCoordsFromReshapedOutCoords(int index) {\n " + coordsFromIndexSnippet + "\n return ivec3(r, c, d);\n }\n "; + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ResizeBilinearBackpropProgram = /** @class */ (function () { + function ResizeBilinearBackpropProgram(dy, x, alignCorners) { + this.variableNames = ['dy']; + this.outputShape = []; + this.outputShape = x.shape; + var _a = x.shape, xHeight = _a[1], xWidth = _a[2]; + var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; + // In the backwards pass, we want to find the pixels that were generated for + // each pixel in the input image the forward pass and add the corresponding + // coefficient from dy to the gradient (with some interpolation). + var effectiveXSize = [ + (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, + (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth + ]; + var effectiveYSize = [ + (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, + (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth + ]; + var heightScale = effectiveXSize[0] / effectiveYSize[0]; + var widthScale = effectiveXSize[1] / effectiveYSize[1]; + var invHeightScale = 1 / heightScale; + var invWidthScale = 1 / widthScale; + // This defines the size of the window of values around a particular + // index in dy that we want to search for contributions to dx. + var winHeight = (Math.ceil(invHeightScale) * 2) + 2; + var winWidth = (Math.ceil(invWidthScale) * 2) + 2; + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(startRLerp - float(winHeight / 2));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(startCLerp - float(winWidth / 2));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float dxR = float(dyR) * heightScale;\n int topDxRIndex = int(floor(dxR));\n int bottomDxRIndex = int(min(ceil(dxR), " + (xHeight - 1) + ".0));\n float dxRLerp = dxR - float(topDxRIndex);\n float inverseDxRLerp = 1.0 - dxRLerp;\n\n float dxC = float(dyC) * widthScale;\n int leftDxCIndex = int(floor(dxC));\n int rightDxCIndex = int(min(ceil(dxC), " + (xWidth - 1) + ".0));\n float dxCLerp = dxC - float(leftDxCIndex);\n float inverseDxCLerp = 1.0 - dxCLerp;\n\n if (r == topDxRIndex && c == leftDxCIndex) {\n // topLeft\n accumulator +=\n getDy(b, dyR, dyC, d) * inverseDxRLerp * inverseDxCLerp;\n }\n\n if (r == topDxRIndex && c == rightDxCIndex) {\n // topRight\n accumulator += getDy(b, dyR, dyC, d) * inverseDxRLerp * dxCLerp;\n }\n\n if (r == bottomDxRIndex && c == leftDxCIndex) {\n // bottomLeft\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * inverseDxCLerp;\n }\n\n if (r == bottomDxRIndex && c == rightDxCIndex) {\n // bottomRight\n accumulator += getDy(b, dyR, dyC, d) * dxRLerp * dxCLerp;\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n "; + } + return ResizeBilinearBackpropProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ResizeBilinearProgram = /** @class */ (function () { + function ResizeBilinearProgram(inputShape, newHeight, newWidth, alignCorners) { + this.variableNames = ['A']; + this.outputShape = []; + var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3]; + this.outputShape = [batch, newHeight, newWidth, depth]; + var effectiveInSize = [ + (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, + (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth + ]; + var effectiveOutSize = [ + (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, + (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth + ]; + this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec2 sourceFloorRC = ivec2(sourceFracIndexRC);\n ivec2 sourceCeilRC = ivec2(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n float topLeft = getA(b, sourceFloorRC.x, sourceFloorRC.y, d);\n float bottomLeft = getA(b, sourceCeilRC.x, sourceFloorRC.y, d);\n float topRight = getA(b, sourceFloorRC.x, sourceCeilRC.y, d);\n float bottomRight = getA(b, sourceCeilRC.x, sourceCeilRC.y, d);\n\n vec2 fracRC = sourceFracIndexRC - vec2(sourceFloorRC);\n\n float top = topLeft + (topRight - topLeft) * fracRC.y;\n float bottom = bottomLeft + (bottomRight - bottomLeft) * fracRC.y;\n float newValue = top + (bottom - top) * fracRC.x;\n\n setOutput(newValue);\n }\n "; + } + return ResizeBilinearProgram; + }()); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ResizeBilinearPackedProgram = /** @class */ (function () { + function ResizeBilinearPackedProgram(inputShape, newHeight, newWidth, alignCorners) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = []; + var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3]; + this.outputShape = [batch, newHeight, newWidth, depth]; + var effectiveInSize = [ + (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, + (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth + ]; + var effectiveOutSize = [ + (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, + (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth + ]; + this.userCode = "\n const vec3 effectiveInputOverOutputRatioRC = vec3(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec3 inputShapeRC = vec3(" + oldHeight + ".0, " + oldWidth + ".0,\n " + oldWidth + ".0);\n\n float getAValue(int b, int r, int c, int d) {\n return getChannel(getA(b, r, c, d), vec2(c, d));\n }\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n // Calculate values for next column in yRC.z.\n ivec3 yRC = coords.yzz + ivec3(0, 0, 1);\n\n // Fractional source index.\n vec3 sourceFracIndexRC = vec3(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the four integer indices.\n ivec3 sourceFloorRC = ivec3(sourceFracIndexRC);\n ivec3 sourceCeilRC = ivec3(\n min(inputShapeRC - 1.0, ceil(sourceFracIndexRC)));\n\n // Should we calculate next column and row elements in 2x2 packed cell.\n bool hasNextCol = d < " + (depth - 1) + ";\n bool hasNextRow = coords.z < " + (newWidth - 1) + ";\n\n // In parallel, construct four corners for all four components in\n // packed 2x2 cell.\n vec4 topLeft = vec4(\n getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 bottomLeft = vec4(\n getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceFloorRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceFloorRC.z, d + 1) : 0.0);\n\n vec4 topRight = vec4(\n getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceFloorRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceFloorRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec4 bottomRight = vec4(\n getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d),\n hasNextCol ? getAValue(b, sourceCeilRC.x, sourceCeilRC.y, d + 1)\n : 0.0,\n hasNextRow ? getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d)\n : 0.0,\n (hasNextRow && hasNextCol) ?\n getAValue(b, sourceCeilRC.x, sourceCeilRC.z, d + 1) : 0.0);\n\n vec3 fracRC = sourceFracIndexRC - vec3(sourceFloorRC);\n\n vec4 top = mix(topLeft, topRight, fracRC.yyzz);\n vec4 bottom = mix(bottomLeft, bottomRight, fracRC.yyzz);\n vec4 newValue = mix(top, bottom, fracRC.x);\n\n setOutput(newValue);\n }\n "; + } + return ResizeBilinearPackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ResizeNearestNeigborBackpropProgram = /** @class */ (function () { + function ResizeNearestNeigborBackpropProgram(dy, x, alignCorners) { + this.variableNames = ['dy']; + this.outputShape = []; + this.outputShape = x.shape; + var _a = x.shape, xHeight = _a[1], xWidth = _a[2]; + var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; + // In the backwards pass, we want to find the pixels that were generated for + // each pixel in the input image the forward pass and add the corresponding + // coefficient from dy to the gradient (with some interpolation). + var effectiveXSize = [ + (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, + (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth + ]; + var effectiveYSize = [ + (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, + (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth + ]; + var heightScale = effectiveXSize[0] / effectiveYSize[0]; + var widthScale = effectiveXSize[1] / effectiveYSize[1]; + var invHeightScale = 1 / heightScale; + var invWidthScale = 1 / widthScale; + // This defines the size of the window of values around a particular + // index in dy that we want to search for contributions to dx. + var winHeight = (Math.ceil(invHeightScale) * 2) + 2; + var winWidth = (Math.ceil(invWidthScale) * 2) + 2; + this.userCode = "\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n int r = coords[1];\n int c = coords[2];\n\n float accumulator = 0.0;\n\n const float heightScale = float(" + heightScale + ");\n const float widthScale = float(" + widthScale + ");\n\n const float invHeightScale = float(" + invHeightScale + ");\n const float invWidthScale = float(" + invWidthScale + ");\n\n const int winHeight = int(" + winHeight + ");\n const int winWidth = int(" + winWidth + ");\n\n // Compute bounds for where in dy we will look\n float startRLerp = floor(float(r) * invHeightScale);\n int startDyR = int(floor(startRLerp - float(winHeight / 2)));\n\n float startCLerp = floor(float(c) * invWidthScale);\n int startDyC = int(floor(startCLerp - float(winWidth / 2)));\n\n // Loop over dy\n for (int dyROffset = 0; dyROffset < winHeight; dyROffset++) {\n int dyR = dyROffset + startDyR;\n\n // Guard against the window exceeding the bounds of dy\n if (dyR < 0 || dyR >= " + yHeight + ") {\n continue;\n }\n\n for (int dyCOffset = 0; dyCOffset < winWidth; dyCOffset++) {\n int dyC = dyCOffset + startDyC;\n\n // Guard against the window exceeding the bounds of dy\n if (dyC < 0 || dyC >= " + yWidth + ") {\n continue;\n }\n\n float sourceFracRow =\n float(" + effectiveXSize[0] + ") *\n (float(dyR) / float(" + effectiveYSize[0] + "));\n\n float sourceFracCol =\n float(" + effectiveXSize[1] + ") *\n (float(dyC) / float(" + effectiveYSize[1] + "));\n\n int sourceNearestRow = int(min(\n float(int(" + xHeight + ") - 1),\n " + alignCorners + " ? float(round(sourceFracRow)) :\n float(floor(sourceFracRow))));\n\n int sourceNearestCol = int(min(\n float(int(" + xWidth + ") - 1),\n " + alignCorners + " ? float(round(sourceFracCol)) :\n float(floor(sourceFracCol))));\n\n if (r == sourceNearestRow && c == sourceNearestCol) {\n accumulator += getDy(b, dyR, dyC, d);\n }\n }\n }\n // End loop over dy\n\n setOutput(accumulator);\n }\n "; + } + return ResizeNearestNeigborBackpropProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ResizeNearestNeighborProgram = /** @class */ (function () { + function ResizeNearestNeighborProgram(inputShape, newHeight, newWidth, alignCorners) { + this.variableNames = ['A']; + this.outputShape = []; + var batch = inputShape[0], oldHeight = inputShape[1], oldWidth = inputShape[2], depth = inputShape[3]; + this.outputShape = [batch, newHeight, newWidth, depth]; + var effectiveInSize = [ + (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, + (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth + ]; + var effectiveOutSize = [ + (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, + (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth + ]; + // When align corners is false, we rounds the value with floor. + var roundBase = alignCorners ? '0.5' : '0.0'; + this.userCode = "\n const vec2 effectiveInputOverOutputRatioRC = vec2(\n " + effectiveInSize[0] / effectiveOutSize[0] + ",\n " + effectiveInSize[1] / effectiveOutSize[1] + ");\n const vec2 inputShapeRC = vec2(" + oldHeight + ".0, " + oldWidth + ".0);\n\n void main() {\n ivec4 coords = getOutputCoords();\n int b = coords[0];\n int d = coords[3];\n ivec2 yRC = coords.yz;\n\n // Fractional source index.\n vec2 sourceFracIndexRC = vec2(yRC) * effectiveInputOverOutputRatioRC;\n\n // Compute the coordinators of nearest neighbor point.\n ivec2 sourceNearestRC = ivec2(\n min(inputShapeRC - 1.0, floor(sourceFracIndexRC + " + roundBase + ")));\n\n float newValue = getA(b, sourceNearestRC.x, sourceNearestRC.y, d);\n\n setOutput(newValue);\n }\n "; + } + return ResizeNearestNeighborProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ReverseProgram = /** @class */ (function () { + function ReverseProgram(xShape, axis) { + this.variableNames = ['x']; + var rank = xShape.length; + if (rank > 4) { + throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported"); + } + this.outputShape = xShape; + if (rank === 1) { + this.userCode = "\n void main() {\n int coord = getOutputCoords();\n setOutput(getX(" + xShape[0] + " - coord - 1));\n }\n "; + return; + } + var getInCoord = function (i) { + if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { + return xShape[i] + " - coords[" + i + "] - 1"; + } + return "coords[" + i + "]"; + }; + var inCoords = xShape.map(function (_, i) { return getInCoord(i); }).join(','); + var type = getCoordsDataType(rank); + this.userCode = "\n void main() {\n " + type + " coords = getOutputCoords();\n setOutput(getX(" + inCoords + "));\n }\n "; + } + return ReverseProgram; + }()); + + /** + * @license + * Copyright 2019 Google LLC All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ReversePackedProgram = /** @class */ (function () { + function ReversePackedProgram(xShape, axis) { + this.variableNames = ['x']; + this.packedInputs = true; + this.packedOutput = true; + var rank = xShape.length; + if (rank > 4) { + throw new Error("WebGL backend: Reverse of rank-" + rank + " tensor is not yet supported"); + } + this.outputShape = xShape; + var channels = getChannels('rc', rank); + var nextColumn = channels[rank - 1] + " + 1 < " + this.outputShape[rank - 1]; + var nextRow = channels[rank - 2] + " + 1 < " + this.outputShape[rank - 2]; + var type = getCoordsDataType(rank); + if (rank === 1) { + this.userCode = "\n void main(){\n int rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = getChannel(getX(" + xShape[0] + " - rc - 1),\n " + xShape[0] + " - rc - 1);\n if(" + nextColumn + "){\n result.g = getChannel(getX(" + xShape[0] + " - (rc + 1) - 1),\n " + xShape[0] + " - (rc + 1) - 1);\n }\n setOutput(result);\n }\n "; + } + else { + this.userCode = "\n void main() {\n " + type + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result.r = " + getR(channels.slice()) + ";\n if(" + nextColumn + "){\n result.g = " + getG(channels.slice()) + ";\n }\n if(" + nextRow + ") {\n result.b = " + getB(channels.slice()) + ";\n if(" + nextColumn + ") {\n result.a = " + getA(channels.slice()) + ";\n }\n }\n setOutput(result);\n }\n "; + } + function getR(channels) { + return getChannel(channels); + } + function getG(channels) { + channels[rank - 1] = '(' + channels[rank - 1] + " + 1)"; + return getChannel(channels); + } + function getB(channels) { + channels[rank - 2] = '(' + channels[rank - 2] + " + 1)"; + return getChannel(channels); + } + function getA(channels) { + channels[rank - 1] = '(' + channels[rank - 1] + " + 1)"; + channels[rank - 2] = '(' + channels[rank - 2] + " + 1)"; + return getChannel(channels); + } + function getChannel(channels) { + var inCoordsArray = xShape.map(function (_, i) { return getInCoord(i, channels); }); + var inCoords = inCoordsArray.join(','); + var innerDims = inCoordsArray.slice(-2).join(','); + return "getChannel(getX(" + inCoords + "), vec2(" + innerDims + "))"; + } + function getInCoord(i, channels1) { + if (axis.indexOf(i) !== -1 && xShape[i] !== 1) { + return xShape[i] + " - " + channels1[i] + " - 1"; + } + else { + return "" + channels1[i]; + } + } + } + return ReversePackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ScatterProgram = /** @class */ (function () { + function ScatterProgram(updateSize, sliceDim, indicesRank, updatesRank, strides, shape, summingDupeIndex) { + if (summingDupeIndex === void 0) { summingDupeIndex = true; } + this.variableNames = ['updates', 'indices', 'defaultValue']; + this.outputShape = shape; + var stridesType = getCoordsDataType(strides.length); + var dtype = getCoordsDataType(shape.length); + var indicesString = ''; + if (indicesRank === 1) { + indicesString = 'i'; + } + else if (indicesRank === 2) { + indicesString = 'i, j'; + } + var indicesSnippet = "getIndices(" + indicesString + ")"; + var updatesString = ''; + if (updatesRank === 1) { + updatesString = 'i'; + } + else if (updatesRank === 2) { + updatesString = 'i, coords[1]'; + } + var updatesSnippet = "getUpdates(" + updatesString + ")"; + var strideString = sliceDim > 1 ? 'strides[j]' : 'strides'; + this.userCode = "\n " + stridesType + " strides = " + stridesType + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n float sum = 0.0;\n bool found = false;\n for (int i = 0; i < " + updateSize + "; i++) {\n int flattenedIndex = 0;\n for (int j = 0; j < " + sliceDim + "; j++) {\n int index = round(" + indicesSnippet + ");\n flattenedIndex += index * " + strideString + ";\n }\n if (flattenedIndex == coords[0]) {\n sum += " + updatesSnippet + ";\n found = true;\n }\n }\n setOutput(mix(getDefaultValue(), sum, float(found)));\n }\n "; + } + return ScatterProgram; + }()); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var SegmentOpProgram = /** @class */ (function () { + function SegmentOpProgram(segOpInfo, segOpType) { + this.variableNames = ['x', 'segmentIds']; + var windowSize = segOpInfo.windowSize; + var batchSize = segOpInfo.batchSize; + var inSize = segOpInfo.inSize; + var numSegments = segOpInfo.numSegments; + var outSize = numSegments * Math.ceil(inSize / windowSize); + this.outputShape = [batchSize, outSize]; + var initializationValue = '0.0'; + var returnValue = "sumValue"; + var windowSizeNearestVec4 = Math.floor(windowSize / 4) * 4; + var windowSizeVec4Remainder = windowSize % 4; + var updateSnippet = "\n sumValue += dot(values, segFilter);\n "; + var checkValueOutOfBounds = ''; + if (inSize % windowSize > 0) { + checkValueOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return initializationValue;\n }\n "; + } + var checkSegmentIdOutOfBounds = ''; + if (inSize % windowSize > 0) { + checkSegmentIdOutOfBounds = "\n if (inIdx < 0 || inIdx >= " + inSize + ") {\n return -1.0;\n }\n "; + } + this.userCode = "\n const float initializationValue = " + initializationValue + ";\n\n float getValue(int batch, int inIdx) {\n " + checkValueOutOfBounds + "\n return getX(batch, inIdx);\n }\n\n float getSegmentIdAtIndex(int inIdx) {\n " + checkSegmentIdOutOfBounds + "\n return getSegmentIds(inIdx);\n }\n\n void main() {\n ivec2 coords = getOutputCoords();\n int batch = coords[0];\n int outIdx = coords[1];\n int inOffset = int(floor(float(outIdx) / float(\n " + numSegments + ")) * float(" + windowSize + "));\n int currentSeg = int(mod(float(outIdx), float(" + numSegments + ")));\n\n float sumValue = 0.0;\n\n for (int i = 0; i < " + windowSizeNearestVec4 + "; i += 4) {\n int inIdx = inOffset + i;\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n getValue(batch, inIdx + 3)\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 3)) == currentSeg ? 1 : 0\n );\n\n " + updateSnippet + "\n }\n\n int inIdx = inOffset + " + windowSizeNearestVec4 + ";\n if (" + (windowSizeVec4Remainder === 1) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n initializationValue,\n initializationValue,\n initializationValue\n );\n\n int inIdxSeg = int(getSegmentIdAtIndex(inIdx));\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 2) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n initializationValue,\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n 0,\n 0\n );\n\n " + updateSnippet + "\n } else if (" + (windowSizeVec4Remainder === 3) + ") {\n vec4 values = vec4(\n getValue(batch, inIdx),\n getValue(batch, inIdx + 1),\n getValue(batch, inIdx + 2),\n initializationValue\n );\n\n vec4 segFilter = vec4(\n int(getSegmentIdAtIndex(inIdx)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 1)) == currentSeg ? 1 : 0,\n int(getSegmentIdAtIndex(inIdx + 2)) == currentSeg ? 1 : 0,\n 0\n );\n\n " + updateSnippet + "\n }\n setOutput(" + returnValue + ");\n }\n "; + } + return SegmentOpProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var SelectProgram = /** @class */ (function () { + function SelectProgram(cRank, shape, rank) { + this.variableNames = ['c', 'a', 'b']; + this.outputShape = shape; + var cCoords; + var abCoords; + if (rank > 4) { + throw Error("Where for rank " + rank + " is not yet supported"); + } + if (rank === 1) { + abCoords = "resRC"; + cCoords = "resRC"; + } + else { + var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w']; + var cCoordVars = []; + var abCoordVars = []; + for (var i = 0; i < shape.length; i++) { + abCoordVars.push("" + currentCoords[i]); + if (i < cRank) { + cCoordVars.push("" + currentCoords[i]); + } + } + cCoords = cCoordVars.join(); + abCoords = abCoordVars.join(); + } + var dtype = getCoordsDataType(rank); + this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n float cVal = getC(" + cCoords + ");\n if (cVal >= 1.0) {\n setOutput(getA(" + abCoords + "));\n } else {\n setOutput(getB(" + abCoords + "));\n }\n }\n "; + } + return SelectProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var SliceProgram = /** @class */ (function () { + function SliceProgram(destSize) { + this.variableNames = ['source']; + this.outputShape = destSize; + this.rank = destSize.length; + var dtype = getCoordsDataType(this.rank); + var uniformPart = "uniform int start[" + this.rank + "];"; + var sourceCoords = getCoords$1(this.rank); + var body; + var coordSum = destSize.map(function (_, i) { + return "sourceLoc." + coords[i] + " = start[" + i + "] + coords." + coords[i] + ";"; + }); + body = "\n " + dtype + " sourceLoc;\n " + dtype + " coords = getOutputCoords();\n " + coordSum.join('\n') + "\n "; + this.userCode = "\n " + uniformPart + "\n void main() {\n " + body + "\n setOutput(getSource(" + sourceCoords + "));\n }\n "; + } + SliceProgram.prototype.getCustomSetupFunc = function (start) { + var _this = this; + if (start.length !== this.rank) { + throw Error("The rank (" + this.rank + ") of the program must match the " + + ("length of start (" + start.length + ")")); + } + return function (gpgpu, webGLProgram) { + if (_this.startLoc == null) { + _this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start'); + if (_this.startLoc == null) { + // This means the compiler has optimized and realized it doesn't need + // the uniform. + return; + } + } + gpgpu.gl.uniform1iv(_this.startLoc, start); + }; + }; + return SliceProgram; + }()); + var coords = ['x', 'y', 'z', 'w', 'u', 'v']; + function getCoords$1(rank) { + if (rank === 1) { + return 'sourceLoc'; + } + else if (rank <= 6) { + return coords.slice(0, rank).map(function (x) { return 'sourceLoc.' + x; }).join(','); + } + else { + throw Error("Slicing for rank " + rank + " is not yet supported"); + } + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var SlicePackedProgram = /** @class */ (function () { + function SlicePackedProgram(destSize) { + this.variableNames = ['source']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = destSize; + this.rank = destSize.length; + var dtype = getCoordsDataType(this.rank); + var coords = getChannels('coords', this.rank); + var sourceLoc = getChannels('sourceLoc', this.rank); + var innerDims = this.rank === 1 ? 'sourceLoc' : "vec2(" + sourceLoc.slice(-2).join() + ")"; + var getChannel = "getChannel(getSource(" + sourceLoc.join() + "), " + innerDims + ")"; + var upperRow = "\n result.x = " + getChannel + ";\n if (++" + coords[this.rank - 1] + " < " + destSize[this.rank - 1] + ") {\n ++" + sourceLoc[this.rank - 1] + ";\n result.y = " + getChannel + ";\n --" + sourceLoc[this.rank - 1] + ";\n }\n "; + var lowerRow = this.rank === 1 ? '' : "\n --" + coords[this.rank - 1] + ";\n if (++" + coords[this.rank - 2] + " < " + destSize[this.rank - 2] + ") {\n ++" + sourceLoc[this.rank - 2] + ";\n result.z = " + getChannel + ";\n if (++" + coords[this.rank - 1] + " < " + destSize[this.rank - 1] + ") {\n ++" + sourceLoc[this.rank - 1] + ";\n result.w = " + getChannel + ";\n }\n }\n "; + var sourceLocSetup = this.rank <= 4 ? + "sourceLoc = coords +\n " + dtype + "(" + destSize.map(function (_, i) { return "start[" + i + "]"; }).join() + ");" : + destSize.map(function (_, i) { return sourceLoc[i] + " = " + coords[i] + " + start[" + i + "];"; }) + .join('\n'); + this.userCode = "\n uniform int start[" + this.rank + "];\n void main() {\n " + dtype + " coords = getOutputCoords();\n " + dtype + " sourceLoc;\n " + sourceLocSetup + "\n vec4 result = vec4(0.);\n " + upperRow + "\n " + lowerRow + "\n setOutput(result);\n }\n "; + } + SlicePackedProgram.prototype.getCustomSetupFunc = function (start) { + var _this = this; + if (start.length !== this.rank) { + throw Error("The rank (" + this.rank + ") of the program must match the " + + ("length of start (" + start.length + ")")); + } + return function (gpgpu, webGLProgram) { + if (_this.startLoc == null) { + _this.startLoc = gpgpu.getUniformLocationNoThrow(webGLProgram, 'start'); + if (_this.startLoc == null) { + // This means the compiler has optimized and realized it doesn't need + // the uniform. + return; + } + } + gpgpu.gl.uniform1iv(_this.startLoc, start); + }; + }; + return SlicePackedProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var StridedSliceProgram = /** @class */ (function () { + function StridedSliceProgram(begin, strides, size) { + this.variableNames = ['x']; + this.outputShape = size; + var rank = size.length; + var inputDtype = getCoordsDataType(size.length); + var dtype = getCoordsDataType(size.length); + var newCoords = ''; + if (rank === 1) { + newCoords = 'coords * strides + begin'; + } + else { + var outputAxis_1 = 0; + newCoords = + size.map(function (_, i) { + outputAxis_1++; + return size.length === 1 ? + "coords * strides[" + i + "] + begin[" + i + "]" : + "coords[" + (outputAxis_1 - 1) + "] * strides[" + i + "] + begin[" + i + "]"; + }) + .join(','); + } + this.userCode = "\n " + inputDtype + " begin = " + inputDtype + "(" + begin + ");\n " + inputDtype + " strides = " + inputDtype + "(" + strides + ");\n\n void main() {\n " + dtype + " coords = getOutputCoords();\n setOutput(getX(" + newCoords + "));\n }\n "; + } + return StridedSliceProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var TextureManager = /** @class */ (function () { + function TextureManager(gpgpu) { + this.gpgpu = gpgpu; + this.numUsedTextures = 0; + this.numFreeTextures = 0; + this.freeTextures = {}; + this.logEnabled = false; + this.usedTextures = {}; + } + TextureManager.prototype.acquireTexture = function (shapeRC, usage, isPacked) { + var physicalTexType = getPhysicalFromLogicalTextureType(usage, isPacked); + var shapeKey = getKeyFromTextureShape(shapeRC, physicalTexType, isPacked); + if (!(shapeKey in this.freeTextures)) { + this.freeTextures[shapeKey] = []; + } + if (!(shapeKey in this.usedTextures)) { + this.usedTextures[shapeKey] = []; + } + if (this.freeTextures[shapeKey].length > 0) { + this.numFreeTextures--; + this.numUsedTextures++; + this.log(); + var newTexture_1 = this.freeTextures[shapeKey].shift(); + this.usedTextures[shapeKey].push(newTexture_1); + return newTexture_1; + } + this.numUsedTextures++; + this.log(); + var newTexture; + if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT32) { + newTexture = this.gpgpu.createPackedMatrixTexture(shapeRC[0], shapeRC[1]); + } + else if (physicalTexType === PhysicalTextureType.PACKED_2X2_FLOAT16) { + newTexture = + this.gpgpu.createFloat16PackedMatrixTexture(shapeRC[0], shapeRC[1]); + } + else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT32) { + newTexture = + this.gpgpu.createFloat32MatrixTexture(shapeRC[0], shapeRC[1]); + } + else if (physicalTexType === PhysicalTextureType.UNPACKED_FLOAT16) { + newTexture = + this.gpgpu.createFloat16MatrixTexture(shapeRC[0], shapeRC[1]); + } + else if (physicalTexType === PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE) { + newTexture = + this.gpgpu.createUnsignedBytesMatrixTexture(shapeRC[0], shapeRC[1]); + } + this.usedTextures[shapeKey].push(newTexture); + return newTexture; + }; + TextureManager.prototype.releaseTexture = function (texture, shape, logicalTexType, isPacked) { + if (this.freeTextures == null) { + // Already disposed. + return; + } + var physicalTexType = getPhysicalFromLogicalTextureType(logicalTexType, isPacked); + var shapeKey = getKeyFromTextureShape(shape, physicalTexType, isPacked); + if (!(shapeKey in this.freeTextures)) { + this.freeTextures[shapeKey] = []; + } + this.freeTextures[shapeKey].push(texture); + this.numFreeTextures++; + this.numUsedTextures--; + var texList = this.usedTextures[shapeKey]; + var texIndex = texList.indexOf(texture); + if (texIndex < 0) { + throw new Error('Cannot release a texture that was never provided by this ' + + 'texture manager'); + } + texList.splice(texIndex, 1); + this.log(); + }; + TextureManager.prototype.log = function () { + if (!this.logEnabled) { + return; + } + var total = this.numFreeTextures + this.numUsedTextures; + console.log('Free/Used', this.numFreeTextures + " / " + this.numUsedTextures, "(" + total + ")"); + }; + TextureManager.prototype.getNumUsedTextures = function () { + return this.numUsedTextures; + }; + TextureManager.prototype.getNumFreeTextures = function () { + return this.numFreeTextures; + }; + TextureManager.prototype.dispose = function () { + var _this = this; + if (this.freeTextures == null) { + // Already disposed. + return; + } + for (var texShape in this.freeTextures) { + this.freeTextures[texShape].forEach(function (tex) { + _this.gpgpu.deleteMatrixTexture(tex); + }); + } + for (var texShape in this.usedTextures) { + this.usedTextures[texShape].forEach(function (tex) { + _this.gpgpu.deleteMatrixTexture(tex); + }); + } + this.freeTextures = null; + this.usedTextures = null; + this.numUsedTextures = 0; + this.numFreeTextures = 0; + }; + return TextureManager; + }()); + function getPhysicalTextureForRendering(isPacked) { + if (env().getBool('WEBGL_RENDER_FLOAT32_ENABLED')) { + if (isPacked) { + return PhysicalTextureType.PACKED_2X2_FLOAT32; + } + return PhysicalTextureType.UNPACKED_FLOAT32; + } + if (isPacked) { + return PhysicalTextureType.PACKED_2X2_FLOAT16; + } + return PhysicalTextureType.UNPACKED_FLOAT16; + } + function getPhysicalFromLogicalTextureType(logicalTexType, isPacked) { + if (logicalTexType === TextureUsage.UPLOAD) { + return PhysicalTextureType.PACKED_2X2_FLOAT32; + } + else if (logicalTexType === TextureUsage.RENDER || logicalTexType == null) { + return getPhysicalTextureForRendering(isPacked); + } + else if (logicalTexType === TextureUsage.DOWNLOAD || + logicalTexType === TextureUsage.PIXELS) { + return PhysicalTextureType.PACKED_4X1_UNSIGNED_BYTE; + } + throw new Error("Unknown logical texture type " + logicalTexType); + } + function getKeyFromTextureShape(shapeRowsCol, physicalTexType, isPacked) { + return shapeRowsCol[0] + "_" + shapeRowsCol[1] + "_" + physicalTexType + "_" + isPacked; + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var TileProgram = /** @class */ (function () { + function TileProgram(aShape, reps) { + this.variableNames = ['A']; + var outputShape = new Array(aShape.length); + for (var i = 0; i < outputShape.length; i++) { + outputShape[i] = aShape[i] * reps[i]; + } + this.outputShape = outputShape; + this.rank = outputShape.length; + var dtype = getCoordsDataType(this.rank); + var sourceCoords = getSourceCoords$2(aShape); + this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + sourceCoords + "));\n }\n "; + } + return TileProgram; + }()); + function getSourceCoords$2(aShape) { + var rank = aShape.length; + if (rank > 5) { + throw Error("Tile for rank " + rank + " is not yet supported"); + } + if (rank === 1) { + return "imod(resRC, " + aShape[0] + ")"; + } + var currentCoords = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u']; + var sourceCoords = []; + for (var i = 0; i < aShape.length; i++) { + sourceCoords.push("imod(" + currentCoords[i] + ", " + aShape[i] + ")"); + } + return sourceCoords.join(); + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var TransposeProgram = /** @class */ (function () { + function TransposeProgram(aShape, newDim) { + this.variableNames = ['A']; + var outputShape = new Array(aShape.length); + for (var i = 0; i < outputShape.length; i++) { + outputShape[i] = aShape[newDim[i]]; + } + this.outputShape = outputShape; + this.rank = outputShape.length; + var dtype = getCoordsDataType(this.rank); + var switched = getSwitchedCoords(newDim); + this.userCode = "\n void main() {\n " + dtype + " resRC = getOutputCoords();\n setOutput(getA(" + switched + "));\n }\n "; + } + return TransposeProgram; + }()); + function getSwitchedCoords(newDim) { + var rank = newDim.length; + if (rank > 6) { + throw Error("Transpose for rank " + rank + " is not yet supported"); + } + var originalOrder = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v']; + var switchedCoords = new Array(rank); + for (var i = 0; i < newDim.length; i++) { + switchedCoords[newDim[i]] = originalOrder[i]; + } + return switchedCoords.join(); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var TransposePackedProgram = /** @class */ (function () { + function TransposePackedProgram(aShape, newDim) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + var outputShape = new Array(aShape.length); + for (var i = 0; i < outputShape.length; i++) { + outputShape[i] = aShape[newDim[i]]; + } + this.outputShape = outputShape; + this.rank = outputShape.length; + if (this.rank > 6) { + throw Error("Packed transpose for rank " + this.rank + " is not yet supported."); + } + var dtype = getCoordsDataType(this.rank); + var outputOrder = getVecChannels('rc', this.rank); + var switchedOrder = new Array(this.rank); + for (var i = 0; i < newDim.length; i++) { + switchedOrder[newDim[i]] = outputOrder[i]; + } + var innerDims = "vec2(" + switchedOrder.slice(-2).join() + ")"; + var nextColumn = "++" + outputOrder[this.rank - 1] + " < " + outputShape[this.rank - 1]; + var getc = "getChannel(getA(" + switchedOrder.join() + "), " + innerDims + ")"; + this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n vec4 result = vec4(0.);\n result[0] = " + getc + ";\n if(" + nextColumn + ") {\n result[1] = " + getc + ";\n }\n --" + outputOrder[this.rank - 1] + ";\n if(++" + outputOrder[this.rank - 2] + " < " + outputShape[this.rank - 2] + ") {\n result[2] = " + getc + ";\n if(" + nextColumn + ") {\n result[3] = " + getc + ";\n }\n }\n setOutput(result);\n }\n "; + } + return TransposePackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var ERF_P = 0.3275911; + var ERF_A1 = 0.254829592; + var ERF_A2 = -0.284496736; + var ERF_A3 = 1.421413741; + var ERF_A4 = -1.453152027; + var ERF_A5 = 1.061405429; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var SELU_SCALEALPHA = 1.7580993408473768599402175208123; + var SELU_SCALE = 1.0507009873554804934193349852946; + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var UnaryOpProgram = /** @class */ (function () { + function UnaryOpProgram(aShape, opSnippet) { + this.variableNames = ['A']; + this.outputShape = aShape; + this.userCode = "\n float unaryOperation(float x) {\n " + opSnippet + "\n }\n\n void main() {\n float x = getAAtOutCoords();\n float y = unaryOperation(x);\n\n setOutput(y);\n }\n "; + } + return UnaryOpProgram; + }()); + var CHECK_NAN_SNIPPET$2 = "if (isnan(x)) return x;"; + var LINEAR = "return x;"; + var ABS = "return abs(x);"; + var RELU = CHECK_NAN_SNIPPET$2 + "\n return (x < 0.0) ? 0.0 : x;\n"; + var RELU6 = CHECK_NAN_SNIPPET$2 + "\n return (x < 0.0) ? 0.0 : min(6.0, x);\n"; + var ELU = "return (x >= 0.0) ? x : (exp(x) - 1.0);"; + var SELU = "\n // Stable and Attracting Fixed Point (0, 1) for Normalized Weights.\n // see: https://arxiv.org/abs/1706.02515\n float scaleAlpha = " + SELU_SCALEALPHA + ";\n float scale = " + SELU_SCALE + ";\n return (x >= 0.0) ? scale * x : scaleAlpha * (exp(x) - 1.0);\n"; + function STEP(alpha) { + if (alpha === void 0) { alpha = 0.0; } + return CHECK_NAN_SNIPPET$2 + ("\n return x > 0.0 ? 1.0 : float(" + alpha + ");\n "); + } + var NEG = "return -x;"; + var CEIL = "return ceil(x);"; + var FLOOR = "return floor(x);"; + var SIGN = "\n if (isnan(x)) { return 0.0; }\n return sign(x);\n"; + var IS_NAN = "return float(isnan(x));"; + var IS_INF = "return float(isinf(x));"; + var IS_FINITE = "return float(!isnan(x) && !isinf(x));"; + var ROUND = "\n // OpenGL ES does not support round function.\n // The algorithm is based on banker's rounding.\n float base = floor(x);\n if ((x - base) < 0.5) {\n return floor(x);\n } else if ((x - base) > 0.5) {\n return ceil(x);\n } else {\n if (mod(base, 2.0) == 0.0) {\n return base;\n } else {\n return base + 1.0;\n }\n }\n"; + var EXP = "return exp(x);"; + var EXPM1 = "return exp(x) - 1.0;"; + var LOG = "if (x < 0.0) return NAN;\n return log(x);"; + var LOG1P = "return log(1.0 + x);"; + var SQRT = "return sqrt(x);"; + var RSQRT = "return inversesqrt(x);"; + var SIGMOID = "return 1.0 / (1.0 + exp(-1.0 * x));"; + /** + * mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX + * + * epsilon is the difference between 1.0 and the next representable + * float. For a single precision 32 bit float this should be 2^-23, see: + * https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm + * + * too_large = (x > -threshold) is value above which exp(x) may overflow + * but softplus(x) == x is within machine epsilon + * + * too_small = (x < threshold) is value below which exp(x) may underflow, + * but softplus(x) == exp(x) is within machine epsilon. + */ + var SOFTPLUS = "\n float epsilon = 1.1920928955078125e-7;\n float threshold = log(epsilon) + 2.0;\n\n bool too_large = x > -threshold;\n bool too_small = x < threshold;\n\n float result;\n float exp_x = exp(x);\n\n if (too_large){\n result = x;\n }\n else if (too_small){\n result = exp_x;\n }\n else{\n result = log(exp_x + 1.0);\n }\n return result;\n"; + var SIN = CHECK_NAN_SNIPPET$2 + "\n return sin(x);\n"; + var COS = CHECK_NAN_SNIPPET$2 + "\n return cos(x);\n"; + var TAN = "return tan(x);"; + var ASIN = CHECK_NAN_SNIPPET$2 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return asin(x);\n"; + var ACOS = CHECK_NAN_SNIPPET$2 + "\n if (abs(x) > 1.) {\n return NAN;\n }\n return acos(x);\n"; + var ATAN = CHECK_NAN_SNIPPET$2 + "\n return atan(x);\n"; + var SINH = "\n float e2x = exp(x);\n return (e2x - 1.0 / e2x) / 2.0;\n"; + var COSH = "\n float e2x = exp(-x);\n return (e2x + 1.0 / e2x) / 2.0;\n"; + var TANH = "\n float e2x = exp(-2.0 * abs(x));\n return sign(x) * (1.0 - e2x) / (1.0 + e2x);\n"; + var ASINH = CHECK_NAN_SNIPPET$2 + "return log(x + sqrt(x * x + 1.0));"; + var ACOSH = CHECK_NAN_SNIPPET$2 + "\n if (x < 1.0) return NAN;\n return log(x + sqrt(x * x - 1.0));"; + var ATANH = CHECK_NAN_SNIPPET$2 + "\n if ((x < -1.0) || (x > 1.0)) return NAN;\n return (log(1.0 + x) - log(1.0 - x)) / 2.0;"; + var ERF = "\n // Error function is calculated approximately with elementary function.\n // See \"Handbook of Mathematical Functions with Formulas,\n // Graphs, and Mathematical Tables\", Abramowitz and Stegun.\n float p = " + ERF_P + ";\n float a1 = " + ERF_A1 + ";\n float a2 = " + ERF_A2 + ";\n float a3 = " + ERF_A3 + ";\n float a4 = " + ERF_A4 + ";\n float a5 = " + ERF_A5 + ";\n\n float sign = sign(x);\n x = abs(x);\n float t = 1.0 / (1.0 + p * x);\n return sign * (1.0 - (((((a5*t + a4)*t) + a3)*t + a2)*t + a1)*t*exp(-x*x));\n"; + var SQUARE = "return x * x;"; + var RECIPROCAL = "return 1.0 / x;"; + var LOGICAL_NOT = "return float(!(x >= 1.0));"; + var TO_INT = "return float(int(x));"; + var CLONE = 'return x;'; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var LINEAR$1 = "return x;"; + var LOG$1 = "\n vec4 result = log(x);\n vec4 isNaN = vec4(lessThan(x, vec4(0.0)));\n result.r = isNaN.r == 1.0 ? NAN : result.r;\n result.g = isNaN.g == 1.0 ? NAN : result.g;\n result.b = isNaN.b == 1.0 ? NAN : result.b;\n result.a = isNaN.a == 1.0 ? NAN : result.a;\n\n return result;\n"; + var RELU$1 = "\n vec4 result = x * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; + var RELU6$1 = "\n vec4 result = min(x, vec4(6.)) * vec4(greaterThanEqual(x, vec4(0.0)));\n bvec4 isNaN = isnan(x);\n\n result.r = isNaN.r ? x.r : result.r;\n result.g = isNaN.g ? x.g : result.g;\n result.b = isNaN.b ? x.b : result.b;\n result.a = isNaN.a ? x.a : result.a;\n\n return result;\n"; + var ELU$1 = "\n vec4 result;\n\n result.r = (x.r >= 0.0) ? x.r : (exp(x.r) - 1.0);\n result.g = (x.g >= 0.0) ? x.g : (exp(x.g) - 1.0);\n result.b = (x.b >= 0.0) ? x.b : (exp(x.b) - 1.0);\n result.a = (x.a >= 0.0) ? x.a : (exp(x.a) - 1.0);\n\n return result;\n"; + var UnaryOpPackedProgram = /** @class */ (function () { + function UnaryOpPackedProgram(aShape, opSnippet) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = true; + this.outputShape = aShape; + this.userCode = "\n vec4 unaryOperation(vec4 x) {\n " + opSnippet + "\n }\n\n void main() {\n vec4 x = getAAtOutCoords();\n vec4 y = unaryOperation(x);\n\n setOutput(y);\n }\n "; + } + return UnaryOpPackedProgram; + }()); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var UnpackProgram = /** @class */ (function () { + function UnpackProgram(outputShape) { + this.variableNames = ['A']; + this.packedInputs = true; + this.packedOutput = false; + this.outputShape = outputShape; + var rank = outputShape.length; + var channels = getChannels('rc', rank); + var dtype = getCoordsDataType(rank); + var sourceCoords = getSourceCoords(rank, channels); + var innerDims = channels.slice(-2); + var coords = rank <= 1 ? 'rc' : "vec2(" + innerDims.join(',') + ")"; + this.userCode = "\n void main() {\n " + dtype + " rc = getOutputCoords();\n vec4 packedInput = getA(" + sourceCoords + ");\n\n setOutput(getChannel(packedInput, " + coords + "));\n }\n "; + } + return UnpackProgram; + }()); + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var binaryCaches = {}; + function getBinaryCache(webGLVersion) { + if (webGLVersion in binaryCaches) { + return binaryCaches[webGLVersion]; + } + binaryCaches[webGLVersion] = {}; + return binaryCaches[webGLVersion]; + } + function mapActivationToShaderProgram(activation, packed) { + if (packed === void 0) { packed = false; } + if (activation === 'linear') { + if (packed) { + return LINEAR$1; + } + return LINEAR; + } + else if (activation === 'relu') { + if (packed) { + return RELU$1; + } + return RELU; + } + else if (activation === 'elu') { + if (packed) { + return ELU$1; + } + return ELU; + } + else if (activation === 'relu6') { + if (packed) { + return RELU6$1; + } + return RELU6; + } + else if (activation === 'prelu') { + if (packed) { + return PRELU$1; + } + return PRELU; + } + throw new Error("Activation " + activation + " has not been implemented for the WebGL backend."); + } + // Empirically determined constant used to determine size threshold for handing + // off execution to the CPU. + var CPU_HANDOFF_SIZE_THRESHOLD = 128; + // Empirically determined constant used to decide the number of MB on GPU + // before we warn about high memory use. The MB are this constant * screen area + // * dpi / 1024 / 1024. + var BEFORE_PAGING_CONSTANT = 600; + function numMBBeforeWarning() { + if (env().global.screen == null) { + return 1024; // 1 GB. + } + return (env().global.screen.height * env().global.screen.width * + window.devicePixelRatio) * + BEFORE_PAGING_CONSTANT / 1024 / 1024; + } + // Empirically determined minimal shared dimension in matmul before we forward + // to a.mul(b).sum() in order to take advantage of GPU parallelism. See + // https://github.com/tensorflow/tfjs-core/pull/1379 for benchmarks. + var MATMUL_SHARED_DIM_THRESHOLD = 1000; + var MathBackendWebGL = /** @class */ (function (_super) { + __extends(MathBackendWebGL, _super); + function MathBackendWebGL(gpgpu) { + var _this = _super.call(this) || this; + _this.gpgpu = gpgpu; + // Maps data ids that have a pending read operation, to list of subscribers. + _this.pendingRead = new WeakMap(); + // List of data ids that are scheduled for disposal, but are waiting on a + // pending read operation. + _this.pendingDisposal = new WeakSet(); + // Used to count the number of 'shallow' sliced tensors that point to the + // same data id. + _this.dataRefCount = new WeakMap(); + _this.numBytesInGPU = 0; + // Accumulated time spent (including blocking) in uploading data to webgl. + _this.uploadWaitMs = 0; + // Accumulated time spent (including blocking in downloading data from webgl. + _this.downloadWaitMs = 0; + _this.warnedAboutMemory = false; + _this.pendingDeletes = 0; + _this.disposed = false; + if (!env().getBool('HAS_WEBGL')) { + throw new Error('WebGL is not supported on this device'); + } + if (gpgpu == null) { + var gl = getWebGLContext(env().getNumber('WEBGL_VERSION')); + _this.binaryCache = getBinaryCache(env().getNumber('WEBGL_VERSION')); + _this.gpgpu = new GPGPUContext(gl); + _this.canvas = gl.canvas; + _this.gpgpuCreatedLocally = true; + } + else { + _this.binaryCache = {}; + _this.gpgpuCreatedLocally = false; + _this.canvas = gpgpu.gl.canvas; + } + _this.textureManager = new TextureManager(_this.gpgpu); + _this.numMBBeforeWarning = numMBBeforeWarning(); + _this.texData = new DataStorage(_this, ENGINE); + return _this; + } + MathBackendWebGL.prototype.numDataIds = function () { + return this.texData.numDataIds() + + (this.cpuBackend ? this.cpuBackend.numDataIds() : 0) - + this.pendingDeletes; + }; + MathBackendWebGL.prototype.fromPixels = function (pixels, numChannels) { + if (pixels == null) { + throw new Error('pixels passed to tf.browser.fromPixels() can not be null'); + } + var isCanvas = (typeof (OffscreenCanvas) !== 'undefined' && + pixels instanceof OffscreenCanvas) || + (typeof (HTMLCanvasElement) !== 'undefined' && + pixels instanceof HTMLCanvasElement); + var isPixelData = pixels.data instanceof Uint8Array; + var isImageData = typeof (ImageData) !== 'undefined' && pixels instanceof ImageData; + var isVideo = typeof (HTMLVideoElement) !== 'undefined' && + pixels instanceof HTMLVideoElement; + var isImage = typeof (HTMLImageElement) !== 'undefined' && + pixels instanceof HTMLImageElement; + var _a = isVideo ? + [ + pixels.videoWidth, + pixels.videoHeight + ] : + [pixels.width, pixels.height], width = _a[0], height = _a[1]; + var texShape = [height, width]; + var outShape = [height, width, numChannels]; + if (!isCanvas && !isPixelData && !isImageData && !isVideo && !isImage) { + throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' + + "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData " + + "in browser, or OffscreenCanvas, ImageData in webworker" + + " or {data: Uint32Array, width: number, height: number}, " + + ("but was " + pixels.constructor.name)); + } + if (isImage || isVideo) { + if (this.fromPixels2DContext == null) { + //@ts-ignore + this.fromPixels2DContext = + createCanvas(env().getNumber('WEBGL_VERSION')).getContext('2d'); + } + this.fromPixels2DContext.canvas.width = width; + this.fromPixels2DContext.canvas.height = height; + this.fromPixels2DContext.drawImage(pixels, 0, 0, width, height); + //@ts-ignore + pixels = this.fromPixels2DContext.canvas; + } + var tempPixelHandle = this.makeTensorInfo(texShape, 'int32'); + // This is a byte texture with pixels. + this.texData.get(tempPixelHandle.dataId).usage = TextureUsage.PIXELS; + this.gpgpu.uploadPixelDataToTexture(this.getTexture(tempPixelHandle.dataId), pixels); + var program, res; + if (env().getBool('WEBGL_PACK')) { + program = new FromPixelsPackedProgram(outShape); + res = this.compileAndRun(program, [tempPixelHandle]); + } + else { + program = new FromPixelsProgram(outShape); + res = this.compileAndRun(program, [tempPixelHandle]); + } + this.disposeData(tempPixelHandle.dataId); + return res; + }; + MathBackendWebGL.prototype.write = function (values, shape, dtype) { + if (env().getBool('DEBUG')) { + this.checkNumericalProblems(values); + } + if (dtype === 'complex64' && values != null) { + throw new Error("Cannot write to a complex64 dtype. " + + "Please use tf.complex(real, imag)."); + } + var dataId = {}; + this.texData.set(dataId, { shape: shape, dtype: dtype, values: values, usage: TextureUsage.UPLOAD }); + return dataId; + }; + MathBackendWebGL.prototype.move = function (dataId, values, shape, dtype) { + if (env().getBool('DEBUG')) { + this.checkNumericalProblems(values); + } + if (dtype === 'complex64') { + throw new Error("Cannot write to a complex64 dtype. " + + "Please use tf.complex(real, imag)."); + } + this.texData.set(dataId, { shape: shape, dtype: dtype, values: values, usage: TextureUsage.UPLOAD }); + }; + MathBackendWebGL.prototype.readSync = function (dataId) { + var texData = this.texData.get(dataId); + var values = texData.values, dtype = texData.dtype, complexTensors = texData.complexTensors, slice = texData.slice, shape = texData.shape, isPacked = texData.isPacked; + if (slice != null) { + var program = void 0; + if (isPacked) { + program = new UnaryOpPackedProgram(shape, CLONE); + } + else { + program = new UnaryOpProgram(shape, CLONE); + } + var res = this.runWebGLProgram(program, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype); + var data = this.readSync(res.dataId); + this.disposeData(res.dataId); + return data; + } + if (values != null) { + return this.convertAndCacheOnCPU(dataId); + } + if (dtype === 'string') { + return values; + } + var shouldTimeProgram = this.activeTimers != null; + var start; + if (shouldTimeProgram) { + start = now(); + } + var result; + if (dtype === 'complex64') { + var realValues = complexTensors.real.dataSync(); + var imagValues = complexTensors.imag.dataSync(); + result = mergeRealAndImagArrays(realValues, imagValues); + } + else { + result = this.getValuesFromTexture(dataId); + } + if (shouldTimeProgram) { + this.downloadWaitMs += now() - start; + } + return this.convertAndCacheOnCPU(dataId, result); + }; + MathBackendWebGL.prototype.read = function (dataId) { + return __awaiter(this, void 0, void 0, function () { + var subscribers_1, texData, values, shape, slice, dtype, complexTensors, isPacked, program, res, data, buffer, tmpDownloadTarget, tmpData, vals, ps, realValues, imagValues, size, dTypeVals, subscribers; + var _a; + return __generator(this, function (_b) { + switch (_b.label) { + case 0: + if (this.pendingRead.has(dataId)) { + subscribers_1 = this.pendingRead.get(dataId); + return [2 /*return*/, new Promise(function (resolve) { return subscribers_1.push(resolve); })]; + } + texData = this.texData.get(dataId); + values = texData.values, shape = texData.shape, slice = texData.slice, dtype = texData.dtype, complexTensors = texData.complexTensors, isPacked = texData.isPacked; + if (slice != null) { + program = void 0; + if (isPacked) { + program = new UnaryOpPackedProgram(shape, CLONE); + } + else { + program = new UnaryOpProgram(shape, CLONE); + } + res = this.runWebGLProgram(program, [{ dataId: dataId, shape: shape, dtype: dtype }], dtype); + data = this.read(res.dataId); + this.disposeData(res.dataId); + return [2 /*return*/, data]; + } + if (values != null) { + return [2 /*return*/, this.convertAndCacheOnCPU(dataId)]; + } + if (!env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED') && + env().getNumber('WEBGL_VERSION') === 2) { + throw new Error("tensor.data() with WEBGL_DOWNLOAD_FLOAT_ENABLED=false and " + + "WEBGL_VERSION=2 not yet supported."); + } + buffer = null; + if (dtype !== 'complex64' && env().get('WEBGL_BUFFER_SUPPORTED')) { + // Possibly copy the texture into a buffer before inserting a fence. + tmpDownloadTarget = this.decode(dataId); + tmpData = this.texData.get(tmpDownloadTarget.dataId); + buffer = (_a = this.gpgpu).createBufferFromTexture.apply(_a, [tmpData.texture].concat(getDenseTexShape(shape))); + } + this.pendingRead.set(dataId, []); + if (!(dtype !== 'complex64')) return [3 /*break*/, 2]; + // Create a fence and wait for it to resolve. + return [4 /*yield*/, this.gpgpu.createAndWaitForFence()]; + case 1: + // Create a fence and wait for it to resolve. + _b.sent(); + _b.label = 2; + case 2: + if (!(dtype === 'complex64')) return [3 /*break*/, 4]; + return [4 /*yield*/, Promise.all([complexTensors.real.data(), complexTensors.imag.data()])]; + case 3: + ps = _b.sent(); + realValues = ps[0]; + imagValues = ps[1]; + vals = mergeRealAndImagArrays(realValues, imagValues); + return [3 /*break*/, 5]; + case 4: + if (buffer == null) { + vals = this.getValuesFromTexture(dataId); + } + else { + size = sizeFromShape(shape); + vals = this.gpgpu.downloadFloat32MatrixFromBuffer(buffer, size); + } + _b.label = 5; + case 5: + if (tmpDownloadTarget != null) { + this.disposeData(tmpDownloadTarget.dataId); + } + dTypeVals = this.convertAndCacheOnCPU(dataId, vals); + subscribers = this.pendingRead.get(dataId); + this.pendingRead.delete(dataId); + // Notify all pending reads. + subscribers.forEach(function (resolve) { return resolve(dTypeVals); }); + if (this.pendingDisposal.has(dataId)) { + this.pendingDisposal.delete(dataId); + this.disposeData(dataId); + this.pendingDeletes--; + } + return [2 /*return*/, dTypeVals]; + } + }); + }); + }; + MathBackendWebGL.prototype.checkNumericalProblems = function (values) { + if (values == null) { + return; + } + for (var i = 0; i < values.length; i++) { + var num = values[i]; + if (!canBeRepresented(num)) { + if (env().getBool('WEBGL_RENDER_FLOAT32_CAPABLE')) { + throw Error("The value " + num + " cannot be represented with your " + + "current settings. Consider enabling float32 rendering: " + + "'tf.env().set('WEBGL_RENDER_FLOAT32_ENABLED', true);'"); + } + throw Error("The value " + num + " cannot be represented on this device."); + } + } + }; + MathBackendWebGL.prototype.getValuesFromTexture = function (dataId) { + var _a; + var _b = this.texData.get(dataId), shape = _b.shape, dtype = _b.dtype, isPacked = _b.isPacked; + var size = sizeFromShape(shape); + if (env().getBool('WEBGL_DOWNLOAD_FLOAT_ENABLED')) { + var tmpTarget = this.decode(dataId); + var tmpData_1 = this.texData.get(tmpTarget.dataId); + var vals_1 = (_a = this.gpgpu).downloadMatrixFromPackedTexture.apply(_a, [tmpData_1.texture].concat(getDenseTexShape(shape))).subarray(0, size); + this.disposeData(tmpTarget.dataId); + return vals_1; + } + var shouldUsePackedProgram = env().getBool('WEBGL_PACK') && isPacked === true; + var outputShape = shouldUsePackedProgram ? getShapeAs3D(shape) : shape; + var program = shouldUsePackedProgram ? + new EncodeFloatPackedProgram(outputShape) : + new EncodeFloatProgram(outputShape); + var output = this.runWebGLProgram(program, [{ shape: outputShape, dtype: dtype, dataId: dataId }], 'float32'); + var tmpData = this.texData.get(output.dataId); + var vals = this.gpgpu + .downloadByteEncodedFloatMatrixFromOutputTexture(tmpData.texture, tmpData.texShape[0], tmpData.texShape[1]) + .subarray(0, size); + this.disposeData(output.dataId); + return vals; + }; + MathBackendWebGL.prototype.time = function (f) { + return __awaiter(this, void 0, void 0, function () { + var oldActiveTimers, newActiveTimers, outerMostTime, flattenedActiveTimerQueries, flattenedActiveTimerNames, kernelMs, res; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + oldActiveTimers = this.activeTimers; + newActiveTimers = []; + outerMostTime = false; + if (this.programTimersStack == null) { + this.programTimersStack = newActiveTimers; + outerMostTime = true; + } + else { + this.activeTimers.push(newActiveTimers); + } + this.activeTimers = newActiveTimers; + f(); + flattenedActiveTimerQueries = flatten(this.activeTimers.map(function (d) { return d.query; })) + .filter(function (d) { return d != null; }); + flattenedActiveTimerNames = flatten(this.activeTimers.map(function (d) { return d.name; })) + .filter(function (d) { return d != null; }); + this.activeTimers = oldActiveTimers; + if (outerMostTime) { + this.programTimersStack = null; + } + return [4 /*yield*/, Promise.all(flattenedActiveTimerQueries)]; + case 1: + kernelMs = _a.sent(); + res = { + uploadWaitMs: this.uploadWaitMs, + downloadWaitMs: this.downloadWaitMs, + kernelMs: sum(kernelMs), + getExtraProfileInfo: function () { + return kernelMs.map(function (d, i) { return ({ name: flattenedActiveTimerNames[i], ms: d }); }) + .map(function (d) { return d.name + ": " + d.ms; }) + .join(', '); + }, + wallMs: null // will be filled by the engine + }; + this.uploadWaitMs = 0; + this.downloadWaitMs = 0; + return [2 /*return*/, res]; + } + }); + }); + }; + MathBackendWebGL.prototype.memory = function () { + return { unreliable: false, numBytesInGPU: this.numBytesInGPU }; + }; + MathBackendWebGL.prototype.startTimer = function () { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + return this.gpgpu.beginQuery(); + } + return { startMs: now(), endMs: null }; + }; + MathBackendWebGL.prototype.endTimer = function (query) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + this.gpgpu.endQuery(); + return query; + } + query.endMs = now(); + return query; + }; + MathBackendWebGL.prototype.getQueryTime = function (query) { + return __awaiter(this, void 0, void 0, function () { + var timerQuery; + return __generator(this, function (_a) { + if (env().getNumber('WEBGL_DISJOINT_QUERY_TIMER_EXTENSION_VERSION') > 0) { + return [2 /*return*/, this.gpgpu.waitForQueryAndGetTime(query)]; + } + timerQuery = query; + return [2 /*return*/, timerQuery.endMs - timerQuery.startMs]; + }); + }); + }; + MathBackendWebGL.prototype.disposeData = function (dataId) { + if (this.pendingDisposal.has(dataId)) { + return; + } + if (this.pendingRead.has(dataId)) { + this.pendingDisposal.add(dataId); + this.pendingDeletes++; + return; + } + // No-op if already disposed. + if (!this.texData.has(dataId)) { + return; + } + this.releaseGPUData(dataId); + var complexTensors = this.texData.get(dataId).complexTensors; + if (complexTensors != null) { + complexTensors.real.dispose(); + complexTensors.imag.dispose(); + } + this.texData.delete(dataId); + }; + MathBackendWebGL.prototype.releaseGPUData = function (dataId) { + var _a = this.texData.get(dataId), texture = _a.texture, dtype = _a.dtype, texShape = _a.texShape, usage = _a.usage, isPacked = _a.isPacked, slice = _a.slice; + var key = slice && slice.origDataId || dataId; + var refCount = this.dataRefCount.get(key); + if (refCount > 1) { + this.dataRefCount.set(key, refCount - 1); + } + else { + this.dataRefCount.delete(key); + if (texture != null) { + this.numBytesInGPU -= this.computeBytes(texShape, dtype); + this.textureManager.releaseTexture(texture, texShape, usage, isPacked); + } + } + var texData = this.texData.get(dataId); + texData.texture = null; + texData.texShape = null; + texData.isPacked = false; + texData.slice = null; + }; + MathBackendWebGL.prototype.getTexture = function (dataId) { + this.uploadToGPU(dataId); + return this.texData.get(dataId).texture; + }; + /** + * Returns internal information for the specific data bucket. Used in unit + * tests. + */ + MathBackendWebGL.prototype.getDataInfo = function (dataId) { + return this.texData.get(dataId); + }; + MathBackendWebGL.prototype.getCPUBackend = function () { + if (!env().getBool('WEBGL_CPU_FORWARD')) { + return null; + } + if (this.cpuBackend == null) { + this.cpuBackend = ENGINE.findBackend('cpu'); + } + return this.cpuBackend; + }; + /* + Tests whether all the inputs to an op are small and on the CPU. This heuristic + determines when it would be faster to execute a kernel on the CPU. WebGL + kernels opt into running this check and forwarding when appropriate. + TODO(https://github.com/tensorflow/tfjs/issues/872): Develop a more + sustainable strategy for optimizing backend execution of ops. + */ + MathBackendWebGL.prototype.shouldExecuteOnCPU = function (inputs, sizeThreshold) { + var _this = this; + if (sizeThreshold === void 0) { sizeThreshold = CPU_HANDOFF_SIZE_THRESHOLD; } + return this.getCPUBackend() != null && + inputs.every(function (input) { return _this.texData.get(input.dataId).texture == null && + input.size < sizeThreshold; }); + }; + MathBackendWebGL.prototype.getGPGPUContext = function () { + return this.gpgpu; + }; + MathBackendWebGL.prototype.complex = function (real, imag) { + var result = this.makeOutput(real.shape, 'complex64'); + var resultData = this.texData.get(result.dataId); + // The backend owns the reference to the underlying real and imaginary + // clones. These will explicitly get disposed when the complex tensor is + // disposed. + resultData.complexTensors = { + real: ENGINE.keep(real.clone()), + imag: ENGINE.keep(imag.clone()) + }; + return result; + }; + MathBackendWebGL.prototype.real = function (input) { + var resultData = this.texData.get(input.dataId); + return resultData.complexTensors.real.clone(); + }; + MathBackendWebGL.prototype.imag = function (input) { + var resultData = this.texData.get(input.dataId); + return resultData.complexTensors.imag.clone(); + }; + MathBackendWebGL.prototype.slice = function (x, begin, size) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.slice(x, begin, size); + } + // Short-circuit computation if the slice is zero-sized. + if (sizeFromShape(size) === 0) { + return tensor([], size, x.dtype); + } + var isPacked = this.texData.get(x.dataId).isPacked; + var isContinous = isSliceContinous(x.shape, begin, size); + if (isPacked || !isContinous) { + var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + new SlicePackedProgram(size) : + new SliceProgram(size); + var customSetup = program.getCustomSetupFunc(begin); + return this.compileAndRun(program, [x], null, customSetup); + } + this.uploadToGPU(x.dataId); + return this.shallowSlice(x, begin, size); + }; + MathBackendWebGL.prototype.shallowSlice = function (x, begin, size) { + var xTexData = this.texData.get(x.dataId); + var t = this.makeOutput(size, x.dtype); + var newTexData = this.texData.get(t.dataId); + // Copy texture data from the original tensor. + Object.assign(newTexData, xTexData); + newTexData.shape = size; + newTexData.dtype = x.dtype; + var flatOffset = computeFlatOffset(begin, x.strides); + if (xTexData.slice) { + // We are slicing an already sliced tensor, so we have to accumulate + // the offset. + flatOffset += xTexData.slice.flatOffset; + } + newTexData.slice = { + flatOffset: flatOffset, + // Point to the original dataId, which is used to do ref counting. + origDataId: xTexData.slice && xTexData.slice.origDataId || x.dataId + }; + // Increase the ref count for that data bucket. + var refCount = this.dataRefCount.get(newTexData.slice.origDataId) || 1; + this.dataRefCount.set(newTexData.slice.origDataId, refCount + 1); + return t; + }; + MathBackendWebGL.prototype.stridedSlice = function (x, begin, end, strides) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.stridedSlice(x, begin, end, strides); + } + var outShape = computeOutShape$2(begin, end, strides); + if (outShape.some(function (axis) { return axis === 0; })) { + return tensor([], outShape); + } + var program = new StridedSliceProgram(begin, strides, outShape); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.reverse = function (x, axis) { + var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + new ReversePackedProgram(x.shape, axis) : + new ReverseProgram(x.shape, axis); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.concat = function (tensors, axis) { + if (tensors[0].dtype === 'complex64') { + var reals = tensors.map(function (t) { return real(t); }); + var imags = tensors.map(function (t) { return imag(t); }); + return complex(this.concat(reals, axis), this.concat(imags, axis)); + } + if (this.shouldExecuteOnCPU(tensors)) { + return this.cpuBackend.concat(tensors, axis); + } + if (tensors.length === 1) { + return tensors[0]; + } + if (tensors.length > env().getNumber('WEBGL_MAX_TEXTURES_IN_SHADER')) { + var midIndex = Math.floor(tensors.length / 2); + var leftSide = this.concat(tensors.slice(0, midIndex), axis); + var rightSide = this.concat(tensors.slice(midIndex), axis); + return this.concat([leftSide, rightSide], axis); + } + if (env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') && tensors[0].rank > 1) { + var program_1 = new ConcatPackedProgram(tensors.map(function (t) { return t.shape; }), axis); + return this.compileAndRun(program_1, tensors); + } + // Any concat of n-dimensional tensors across any axis can be reduced to + // a concatenation of two-dimensional tensors across the axis 1 by first + // partitioning the axes of the original tensors into those less than the + // axis to be concatenated and the rest. Then reshape the tensors + // into a two-dimensional tensor by collapsing these two sets of axes and + // concatenate the resulting matrices across the axis 1, finally reshaping + // the result to have the proper shape. + var outShape = computeOutShape(tensors.map(function (t) { return t.shape; }), axis); + var tensors2D = tensors.map(function (t) { return t.as2D(-1, sizeFromShape(t.shape.slice(axis))); }); + var program = new ConcatProgram(tensors2D.map(function (t) { return t.shape; })); + var res = this.compileAndRun(program, tensors2D); + return res.reshape(outShape); + }; + MathBackendWebGL.prototype.neg = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.neg(x); + } + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, NEG, x.dtype); + } + var program = new UnaryOpProgram(x.shape, NEG); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.batchMatMul = function (a, b, transposeA, transposeB) { + var outerShapeA = transposeA ? a.shape[2] : a.shape[1]; + var outerShapeB = transposeB ? b.shape[1] : b.shape[2]; + var sharedDim = transposeA ? a.shape[1] : a.shape[2]; + var _a = a.shape, batch = _a[0]; + // Since the matrices are vectors, it is faster to call mul().sum() + // because sum() is O(sqrt(N)) due to divide-and-conquer. + if ((outerShapeA === 1 || outerShapeB === 1) && + sharedDim > MATMUL_SHARED_DIM_THRESHOLD) { + if (transposeA) { + a = a.transpose([0, 2, 1]); + } + if (transposeB) { + b = b.transpose([0, 2, 1]); + } + var a3D = outerShapeB === 1 ? a : a.as3D(batch, sharedDim, 1); + var axis = outerShapeB === 1 ? 2 : 1; + var b3D = outerShapeB === 1 ? b.as3D(batch, 1, sharedDim) : b; + return this.multiply(a3D, b3D).sum(axis, true /* keepDims */); + } + var dtype = upcastType(a.dtype, b.dtype); + var program = new MatMulPackedProgram(a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB); + return this.compileAndRun(program, [a, b], dtype); + }; + MathBackendWebGL.prototype.fusedBatchMatMul = function (_a) { + var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var outerShapeA = transposeA ? a.shape[2] : a.shape[1]; + var outerShapeB = transposeB ? b.shape[1] : b.shape[2]; + var _b = a.shape, batch = _b[0]; + var dtype = upcastType(a.dtype, b.dtype); + var hasBias = bias != null; + var hasPreluActivationWeights = preluActivationWeights != null; + var fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null; + var program = new MatMulPackedProgram(a.shape, [batch, outerShapeA, outerShapeB], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights); + var inputs = [a, b]; + if (bias) { + inputs.push(bias); + } + if (preluActivationWeights) { + inputs.push(preluActivationWeights); + } + return this.compileAndRun(program, inputs, dtype); + }; + MathBackendWebGL.prototype.multiply = function (a, b) { + if (a.dtype === 'complex64') { + var aData = this.texData.get(a.dataId); + var bData = this.texData.get(b.dataId); + var realProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.REAL, a.shape, b.shape); + var imagProgram = new BinaryOpComplexProgram(COMPLEX_MULTIPLY.IMAG, a.shape, b.shape); + var inputs = [ + this.makeComplexComponentTensorInfo(a, aData.complexTensors.real), + this.makeComplexComponentTensorInfo(a, aData.complexTensors.imag), + this.makeComplexComponentTensorInfo(b, bData.complexTensors.real), + this.makeComplexComponentTensorInfo(b, bData.complexTensors.imag) + ]; + var real_1 = this.compileAndRun(realProgram, inputs); + var imag_1 = this.compileAndRun(imagProgram, inputs); + var complex_1 = this.complex(real_1, imag_1); + real_1.dispose(); + imag_1.dispose(); + return complex_1; + } + if (this.shouldExecuteOnCPU([a, b])) { + return this.cpuBackend.multiply(a, b); + } + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, MUL, a.dtype); + } + var program = new BinaryOpProgram(MUL, a.shape, b.shape); + return this.compileAndRun(program, [a, b], a.dtype); + }; + MathBackendWebGL.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) { + var inputs = [x, mean, variance]; + var offsetShape = null; + if (offset != null) { + offsetShape = offset.shape; + inputs.push(offset); + } + var scaleShape = null; + if (scale != null) { + scaleShape = scale.shape; + inputs.push(scale); + } + if (env().getBool('WEBGL_PACK_NORMALIZATION')) { + var batchNormPackedProgram = new BatchNormPackedProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon); + return this.compileAndRun(batchNormPackedProgram, inputs); + } + var batchNormProgram = new BatchNormProgram(x.shape, mean.shape, variance.shape, offsetShape, scaleShape, varianceEpsilon); + return this.compileAndRun(batchNormProgram, inputs); + }; + MathBackendWebGL.prototype.localResponseNormalization4D = function (x, radius, bias, alpha, beta) { + var program = env().getBool('WEBGL_PACK_NORMALIZATION') ? + new LRNPackedProgram(x.shape, radius, bias, alpha, beta) : + new LRNProgram(x.shape, radius, bias, alpha, beta); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.LRNGrad = function (dy, inputImage, outputImage, depthRadius, bias, alpha, beta) { + var program = new LRNGradProgram(inputImage.shape, depthRadius, bias, alpha, beta); + return this.compileAndRun(program, [inputImage, outputImage, dy]); + }; + MathBackendWebGL.prototype.tile = function (x, reps) { + if (x.dtype === 'string') { + var data = this.readSync(x.dataId); + var decodedData = data.map(function (d) { return decodeString(d); }); + var buf = buffer(x.shape, x.dtype, decodedData); + return tile$1(buf, reps); + } + var program = new TileProgram(x.shape, reps); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.pad = function (x, paddings, constantValue) { + var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + new PadPackedProgram(x.shape, paddings, constantValue) : + new PadProgram(x.shape, paddings, constantValue); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.transpose = function (x, perm) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.transpose(x, perm); + } + var program = env().getBool('WEBGL_PACK_ARRAY_OPERATIONS') ? + new TransposePackedProgram(x.shape, perm) : + new TransposeProgram(x.shape, perm); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.gather = function (x, indices, axis) { + if (this.shouldExecuteOnCPU([x, indices])) { + return this.cpuBackend.gather(x, indices, axis); + } + var program = new GatherProgram(x.shape, indices.size, axis); + return this.compileAndRun(program, [x, indices]); + }; + MathBackendWebGL.prototype.batchToSpaceND = function (x, blockShape, crops) { + assert(x.rank <= 4, function () { return 'batchToSpaceND for rank > 4 with a WebGL backend not ' + + 'implemented yet'; }); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + var reshaped = getReshaped(x.shape, blockShape, prod); + var permuted = getPermuted(reshaped.length, blockShape.length); + var reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod); + var sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length); + var sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length); + return x.reshape(reshaped) + .transpose(permuted) + .reshape(reshapedPermuted) + .slice(sliceBeginCoords, sliceSize); + }; + MathBackendWebGL.prototype.spaceToBatchND = function (x, blockShape, paddings) { + assert(x.rank <= 4, function () { return 'spaceToBatchND for rank > 4 with a WebGL backend not ' + + 'implemented yet'; }); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + var completePaddings = [[0, 0]]; + completePaddings.push.apply(completePaddings, paddings); + for (var i = 1 + blockShape.length; i < x.shape.length; ++i) { + completePaddings.push([0, 0]); + } + var paddedX = x.pad(completePaddings); + var reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false); + var permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false); + var flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false); + return paddedX.reshape(reshapedPaddedShape) + .transpose(permutedReshapedPaddedPermutation) + .reshape(flattenShape); + }; + MathBackendWebGL.prototype.reduce = function (x, reduceType, dtype) { + var batchSize = x.shape[0]; + var inSize = x.shape[1]; + var windowSize = computeOptimalWindowSize(inSize); + var reduceInfo = { windowSize: windowSize, inSize: inSize, batchSize: batchSize }; + var program = new ReduceProgram(reduceInfo, reduceType); + var output = this.compileAndRun(program, [x], dtype); + // No need to run another GPGPU program. + if (output.shape[1] === 1) { + return output; + } + return this.reduce(output, reduceType, dtype); + }; + MathBackendWebGL.prototype.argReduce = function (x, reduceType, bestIndicesA) { + if (bestIndicesA === void 0) { bestIndicesA = null; } + var batchSize = x.shape[0]; + var inSize = x.shape[1]; + if (bestIndicesA != null) { + batchSize = bestIndicesA.shape[0]; + inSize = bestIndicesA.shape[1]; + } + var windowSize = computeOptimalWindowSize(inSize); + var reduceInfo = { windowSize: windowSize, inSize: inSize, batchSize: batchSize }; + var program = new ArgMinMaxProgram(reduceInfo, reduceType, bestIndicesA == null); + var inputs = [x]; + if (bestIndicesA != null) { + inputs.push(bestIndicesA); + } + var output = this.compileAndRun(program, inputs, 'int32'); + // No need to run another GPGPU program. + if (output.shape[1] === 1) { + return output; + } + return this.argReduce(x, reduceType, output); + }; + MathBackendWebGL.prototype.argReducePacked = function (x, reduceType, bestIndicesA) { + if (bestIndicesA === void 0) { bestIndicesA = null; } + var inShape = bestIndicesA != null ? bestIndicesA.shape : x.shape; + var inSize = inShape[inShape.length - 1]; + var windowSize = computeOptimalWindowSize(inSize); + var program = new ArgMinMaxPackedProgram(inShape, windowSize, reduceType, bestIndicesA == null); + var inputs = bestIndicesA == null ? [x] : [x, bestIndicesA]; + var output = this.compileAndRun(program, inputs, 'int32'); + if (output.rank === x.rank) { + return this.argReducePacked(x, reduceType, output); + } + return output; + }; + MathBackendWebGL.prototype.sum = function (x, axes) { + assertAxesAreInnerMostDims('sum', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var inSize = sizeFromShape(reduceShape); + var a2D = x.as2D(-1, inSize); + var outputDType = sumOutType(x.dtype); + return this.reduce(a2D, 'sum', outputDType).reshape(outShape); + }; + MathBackendWebGL.prototype.prod = function (x, axes) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.prod(x, axes); + } + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var inSize = sizeFromShape(reduceShape); + var a2D = x.as2D(-1, inSize); + var outputDType = sumOutType(x.dtype); + return this.reduce(a2D, 'prod', outputDType).reshape(outShape); + }; + MathBackendWebGL.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) { + var axis = 0; + var permutation = getAxesPermutation([axis], x.rank); + var permutedX = x; + if (permutation != null) { + permutedX = x.transpose(permutation); + axis = getInnerMostAxes(1, x.rank)[0]; + } + var outShape = computeOutShape$1(permutedX.shape, axis, numSegments); + var inSize = sizeFromShape([permutedX.shape[axis]]); + var a2D = permutedX.as2D(-1, inSize); + var outputDType = sumOutType(x.dtype); + var result = this.segOpCompute(a2D, 'unsortedSegmentSum', segmentIds, outputDType, numSegments) + .reshape(outShape); + if (permutation != null) { + result = result.transpose(getUndoAxesPermutation(permutation)); + } + return result; + }; + MathBackendWebGL.prototype.segOpCompute = function (x, segOpType, segmentIds, dtype, numSegments) { + var batchSize = x.shape[0]; + var inSize = x.shape[1]; + var windowSize = segOpComputeOptimalWindowSize(inSize, numSegments); + var segOpInfo = { windowSize: windowSize, inSize: inSize, batchSize: batchSize, numSegments: numSegments }; + var program = new SegmentOpProgram(segOpInfo, segOpType); + var output = this.compileAndRun(program, [x, segmentIds], dtype); + // No need to run another GPGPU program. + if (output.shape[1] === numSegments) { + return output; + } + segmentIds = range(0, numSegments).tile([inSize / windowSize]); + return this.segOpCompute(output, segOpType, segmentIds, dtype, numSegments); + }; + MathBackendWebGL.prototype.argMinMaxReduce = function (x, axis, reduceType) { + var axes = [axis]; + assertAxesAreInnerMostDims('arg' + reduceType.charAt(0).toUpperCase() + reduceType.slice(1), axes, x.rank); + if (!env().getBool('WEBGL_PACK_REDUCE') || x.rank <= 2) { + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var inSize = sizeFromShape(reduceShape); + var a2D = x.as2D(-1, inSize); + return this.argReduce(a2D, reduceType).reshape(outShape); + } + return this.argReducePacked(x, reduceType); + }; + MathBackendWebGL.prototype.argMin = function (x, axis) { + return this.argMinMaxReduce(x, axis, 'min'); + }; + MathBackendWebGL.prototype.argMax = function (x, axis) { + return this.argMinMaxReduce(x, axis, 'max'); + }; + MathBackendWebGL.prototype.cumsum = function (x, axis, exclusive, reverse) { + if (axis !== x.rank - 1) { + throw new Error("WebGL cumsum shader expects an inner-most axis=" + (x.rank - 1) + " " + + ("but got axis=" + axis)); + } + var program = new CumSumProgram(x.shape, exclusive, reverse); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.equal = function (a, b) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, EQUAL$1, 'bool'); + } + var program = new BinaryOpProgram(EQUAL, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.notEqual = function (a, b) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, NOT_EQUAL$1, 'bool'); + } + var program = new BinaryOpProgram(NOT_EQUAL, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.less = function (a, b) { + if (this.shouldExecuteOnCPU([a, b])) { + return this.cpuBackend.less(a, b); + } + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, LESS$1, 'bool'); + } + var program = new BinaryOpProgram(LESS, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.lessEqual = function (a, b) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, LESS_EQUAL$1, 'bool'); + } + var program = new BinaryOpProgram(LESS_EQUAL, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.greater = function (a, b) { + if (this.shouldExecuteOnCPU([a, b])) { + return this.cpuBackend.greater(a, b); + } + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, GREATER$1, 'bool'); + } + var program = new BinaryOpProgram(GREATER, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.greaterEqual = function (a, b) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, GREATER_EQUAL$1, 'bool'); + } + var program = new BinaryOpProgram(GREATER_EQUAL, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.logicalNot = function (x) { + var program = new UnaryOpProgram(x.shape, LOGICAL_NOT); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.logicalAnd = function (a, b) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, LOGICAL_AND$1, 'bool'); + } + var program = new BinaryOpProgram(LOGICAL_AND, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.logicalOr = function (a, b) { + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, LOGICAL_OR$1, 'bool'); + } + var program = new BinaryOpProgram(LOGICAL_OR, a.shape, b.shape); + return this.compileAndRun(program, [a, b], 'bool'); + }; + MathBackendWebGL.prototype.select = function (condition, a, b) { + var program = new SelectProgram(condition.rank, a.shape, a.rank); + return this.compileAndRun(program, [condition, a, b], upcastType(a.dtype, b.dtype)); + }; + MathBackendWebGL.prototype.where = function (condition) { + warn('tf.where() in webgl locks the UI thread. ' + + 'Call tf.whereAsync() instead'); + var condVals = condition.dataSync(); + return whereImpl(condition.shape, condVals); + }; + MathBackendWebGL.prototype.topk = function (x, k, sorted) { + var xVals = x.dataSync(); + return topkImpl(xVals, x.shape, x.dtype, k, sorted); + }; + MathBackendWebGL.prototype.min = function (x, axes) { + assertAxesAreInnerMostDims('min', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var inSize = sizeFromShape(reduceShape); + var a2D = x.as2D(-1, inSize); + return this.reduce(a2D, 'min', a2D.dtype).reshape(outShape); + }; + MathBackendWebGL.prototype.minimum = function (a, b) { + if (this.shouldExecuteOnCPU([a, b])) { + return this.cpuBackend.minimum(a, b); + } + var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + new BinaryOpPackedProgram(MIN$1, a.shape, b.shape) : + new BinaryOpProgram(MIN, a.shape, b.shape); + return this.compileAndRun(program, [a, b]); + }; + MathBackendWebGL.prototype.mod = function (a, b) { + var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + new BinaryOpPackedProgram(MOD$1, a.shape, b.shape) : + new BinaryOpProgram(MOD, a.shape, b.shape); + return this.compileAndRun(program, [a, b]); + }; + MathBackendWebGL.prototype.max = function (x, axes) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.max(x, axes); + } + assertAxesAreInnerMostDims('max', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var inSize = sizeFromShape(reduceShape); + var a2D = x.as2D(-1, inSize); + return this.reduce(a2D, 'max', a2D.dtype).reshape(outShape); + }; + MathBackendWebGL.prototype.maximum = function (a, b) { + if (this.shouldExecuteOnCPU([a, b])) { + return this.cpuBackend.maximum(a, b); + } + var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + new BinaryOpPackedProgram(MAX$1, a.shape, b.shape) : + new BinaryOpProgram(MAX, a.shape, b.shape); + return this.compileAndRun(program, [a, b]); + }; + MathBackendWebGL.prototype.all = function (x, axes) { + assertAxesAreInnerMostDims('all', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var inSize = sizeFromShape(reduceShape); + var a2D = x.as2D(-1, inSize); + return this.reduce(a2D, 'all', a2D.dtype).reshape(outShape); + }; + MathBackendWebGL.prototype.any = function (x, axes) { + assertAxesAreInnerMostDims('any', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var inSize = sizeFromShape(reduceShape); + var a2D = x.as2D(-1, inSize); + return this.reduce(a2D, 'any', a2D.dtype).reshape(outShape); + }; + MathBackendWebGL.prototype.squaredDifference = function (a, b) { + var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + new BinaryOpPackedProgram(SQUARED_DIFFERENCE, a.shape, b.shape) : + new BinaryOpProgram(SQUARED_DIFFERENCE, a.shape, b.shape); + return this.compileAndRun(program, [a, b]); + }; + MathBackendWebGL.prototype.realDivide = function (a, b) { + var op = DIV; + var outputDtype = 'float32'; + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + var checkOutOfBounds = true; + return this.packedBinaryOp(a, b, DIV$1, outputDtype, checkOutOfBounds); + } + var program = new BinaryOpProgram(op, a.shape, b.shape); + return this.compileAndRun(program, [a, b], outputDtype); + }; + MathBackendWebGL.prototype.floorDiv = function (a, b) { + var op = INT_DIV; + var outputDtype = 'int32'; + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, INT_DIV$1, outputDtype); + } + var program = new BinaryOpProgram(op, a.shape, b.shape); + return this.compileAndRun(program, [a, b], outputDtype); + }; + MathBackendWebGL.prototype.add = function (a, b) { + if (a.dtype === 'complex64' && b.dtype === 'complex64') { + return this.complexSeparableBinaryOp(a, b, ADD); + } + if (this.shouldExecuteOnCPU([a, b])) { + return this.cpuBackend.add(a, b); + } + var dtype = upcastType(a.dtype, b.dtype); + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, ADD, dtype); + } + var program = new BinaryOpProgram(ADD, a.shape, b.shape); + return this.compileAndRun(program, [a, b], dtype); + }; + MathBackendWebGL.prototype.packedUnaryOp = function (x, op, dtype) { + var program = new UnaryOpPackedProgram(x.shape, op); + return this.compileAndRun(program, [x], dtype); + }; + MathBackendWebGL.prototype.packedBinaryOp = function (a, b, op, dtype, checkOutOfBounds) { + if (checkOutOfBounds === void 0) { checkOutOfBounds = false; } + var program = new BinaryOpPackedProgram(op, a.shape, b.shape, checkOutOfBounds); + return this.compileAndRun(program, [a, b], dtype); + }; + /** + * Computes a complex binary operation that can be decomposed into a simple + * binary operation on both the real and imagary parts. + */ + MathBackendWebGL.prototype.complexSeparableBinaryOp = function (a, b, op) { + var _this = this; + var aData = this.texData.get(a.dataId); + var bData = this.texData.get(b.dataId); + var _a = [ + [aData.complexTensors.real, bData.complexTensors.real], + [aData.complexTensors.imag, bData.complexTensors.imag] + ].map(function (complexParts) { + var aPart = complexParts[0], bPart = complexParts[1]; + var aHandle = _this.makeComplexComponentTensorInfo(a, aPart); + var bHandle = _this.makeComplexComponentTensorInfo(b, bPart); + var program = new BinaryOpProgram(op, a.shape, b.shape); + return _this.compileAndRun(program, [aHandle, bHandle], upcastType(aPart.dtype, bPart.dtype)); + }), real = _a[0], imag = _a[1]; + var complex = this.complex(real, imag); + real.dispose(); + imag.dispose(); + return complex; + }; + // Returns a TensorInfo with the complex shape and the dataId of the + // underlying part. We need to do this because a reshaped complex tensor is + // not reflected in its parts. + MathBackendWebGL.prototype.makeComplexComponentTensorInfo = function (complexTensor, complexPart) { + return { + dataId: complexPart.dataId, + dtype: complexPart.dtype, + shape: complexTensor.shape + }; + }; + MathBackendWebGL.prototype.addN = function (tensors) { + if (tensors.length === 1) { + return tensors[0]; + } + // Limit the number of uploaded textures for optimization. + if (tensors.length > env().get('WEBGL_MAX_TEXTURES_IN_SHADER')) { + var midIndex = Math.floor(tensors.length / 2); + var leftSide = this.addN(tensors.slice(0, midIndex)); + var rightSide = this.addN(tensors.slice(midIndex)); + return this.addN([leftSide, rightSide]); + } + var dtype = tensors.map(function (t) { return t.dtype; }).reduce(function (d1, d2) { return upcastType(d1, d2); }); + var shapes = tensors.map(function (t) { return t.shape; }); + // We can make sure shapes are identical in op level. + var usePackedOp = env().getBool('WEBGL_PACK'); + var program = usePackedOp ? + new AddNPackedProgram(tensors[0].shape, shapes) : + new AddNProgram(tensors[0].shape, shapes); + return this.compileAndRun(program, tensors, dtype); + }; + MathBackendWebGL.prototype.subtract = function (a, b) { + if (a.dtype === 'complex64' && b.dtype === 'complex64') { + return this.complexSeparableBinaryOp(a, b, SUB); + } + if (this.shouldExecuteOnCPU([a, b])) { + return this.cpuBackend.subtract(a, b); + } + var dtype = upcastType(a.dtype, b.dtype); + if (env().getBool('WEBGL_PACK_BINARY_OPERATIONS')) { + return this.packedBinaryOp(a, b, SUB, a.dtype); + } + var program = new BinaryOpProgram(SUB, a.shape, b.shape); + return this.compileAndRun(program, [a, b], dtype); + }; + MathBackendWebGL.prototype.pow = function (a, b) { + var usePackedOp = env().getBool('WEBGL_PACK_BINARY_OPERATIONS'); + var program = usePackedOp ? + new BinaryOpPackedProgram(POW$1, a.shape, b.shape) : + new BinaryOpProgram(POW, a.shape, b.shape); + var dtype = upcastType(a.dtype, b.dtype); + return this.compileAndRun(program, [a, b], dtype); + }; + MathBackendWebGL.prototype.ceil = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.ceil(x); + } + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, CEIL, x.dtype); + } + var program = new UnaryOpProgram(x.shape, CEIL); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.floor = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.floor(x); + } + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, FLOOR, x.dtype); + } + var program = new UnaryOpProgram(x.shape, FLOOR); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.sign = function (x) { + var program = new UnaryOpProgram(x.shape, SIGN); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.isNaN = function (x) { + var program = new UnaryOpProgram(x.shape, IS_NAN); + return this.compileAndRun(program, [x], 'bool'); + }; + MathBackendWebGL.prototype.isInf = function (x) { + var program = new UnaryOpProgram(x.shape, IS_INF); + return this.compileAndRun(program, [x], 'bool'); + }; + MathBackendWebGL.prototype.isFinite = function (x) { + var program = new UnaryOpProgram(x.shape, IS_FINITE); + return this.compileAndRun(program, [x], 'bool'); + }; + MathBackendWebGL.prototype.round = function (x) { + var program = new UnaryOpProgram(x.shape, ROUND); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.exp = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.exp(x); + } + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, EXP, x.dtype); + } + var program = new UnaryOpProgram(x.shape, EXP); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.expm1 = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.expm1(x); + } + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, EXPM1, x.dtype); + } + var program = new UnaryOpProgram(x.shape, EXPM1); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.log = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.log(x); + } + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, LOG$1, x.dtype); + } + var program = new UnaryOpProgram(x.shape, LOG); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.log1p = function (x) { + var program = new UnaryOpProgram(x.shape, LOG1P); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.sqrt = function (x) { + var program = new UnaryOpProgram(x.shape, SQRT); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.rsqrt = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.rsqrt(x); + } + var program = new UnaryOpProgram(x.shape, RSQRT); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.reciprocal = function (x) { + var program = new UnaryOpProgram(x.shape, RECIPROCAL); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.relu = function (x) { + var program; + if (env().getBool('WEBGL_PACK')) { + program = new UnaryOpPackedProgram(x.shape, RELU$1); + } + else { + program = new UnaryOpProgram(x.shape, RELU); + } + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.relu6 = function (x) { + var program; + if (env().getBool('WEBGL_PACK')) { + program = new UnaryOpPackedProgram(x.shape, RELU6$1); + } + else { + program = new UnaryOpProgram(x.shape, RELU6); + } + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.prelu = function (x, alpha) { + var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + new BinaryOpPackedProgram(PRELU$1, x.shape, alpha.shape) : + new BinaryOpProgram(PRELU, x.shape, alpha.shape); + return this.compileAndRun(program, [x, alpha]); + }; + MathBackendWebGL.prototype.elu = function (x) { + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, ELU$1, x.dtype); + } + var program = new UnaryOpProgram(x.shape, ELU); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.eluDer = function (dy, y) { + var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + new BinaryOpPackedProgram(ELU_DER$1, dy.shape, y.shape) : + new BinaryOpProgram(ELU_DER, dy.shape, y.shape); + return this.compileAndRun(program, [dy, y]); + }; + MathBackendWebGL.prototype.selu = function (x) { + var program = new UnaryOpProgram(x.shape, SELU); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.int = function (x) { + var program = new UnaryOpProgram(x.shape, TO_INT); + return this.compileAndRun(program, [x], 'int32'); + }; + MathBackendWebGL.prototype.clip = function (x, min, max) { + var program; + if (env().getBool('WEBGL_PACK_CLIP')) { + program = new ClipPackedProgram(x.shape); + } + else { + program = new ClipProgram(x.shape); + } + var customSetup = program.getCustomSetupFunc(min, max); + return this.compileAndRun(program, [x], null, customSetup); + }; + MathBackendWebGL.prototype.abs = function (x) { + if (this.shouldExecuteOnCPU([x])) { + return this.cpuBackend.abs(x); + } + if (env().getBool('WEBGL_PACK_UNARY_OPERATIONS')) { + return this.packedUnaryOp(x, ABS, x.dtype); + } + var program = new UnaryOpProgram(x.shape, ABS); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.complexAbs = function (x) { + var xData = this.texData.get(x.dataId); + var program = new ComplexAbsProgram(x.shape); + var inputs = [ + this.makeComplexComponentTensorInfo(x, xData.complexTensors.real), + this.makeComplexComponentTensorInfo(x, xData.complexTensors.imag), + ]; + return this.compileAndRun(program, inputs); + }; + MathBackendWebGL.prototype.sigmoid = function (x) { + var program = new UnaryOpProgram(x.shape, SIGMOID); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.softplus = function (x) { + var program = new UnaryOpProgram(x.shape, SOFTPLUS); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.sin = function (x) { + var program = new UnaryOpProgram(x.shape, SIN); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.cos = function (x) { + var program = new UnaryOpProgram(x.shape, COS); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.tan = function (x) { + var program = new UnaryOpProgram(x.shape, TAN); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.asin = function (x) { + var program = new UnaryOpProgram(x.shape, ASIN); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.acos = function (x) { + var program = new UnaryOpProgram(x.shape, ACOS); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.atan = function (x) { + var program = new UnaryOpProgram(x.shape, ATAN); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.atan2 = function (a, b) { + var program = env().getBool('WEBGL_PACK_BINARY_OPERATIONS') ? + new BinaryOpPackedProgram(ATAN2$1, a.shape, b.shape) : + new BinaryOpProgram(ATAN2, a.shape, b.shape); + return this.compileAndRun(program, [a, b]); + }; + MathBackendWebGL.prototype.sinh = function (x) { + var program = new UnaryOpProgram(x.shape, SINH); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.cosh = function (x) { + var program = new UnaryOpProgram(x.shape, COSH); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.tanh = function (x) { + var program = new UnaryOpProgram(x.shape, TANH); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.asinh = function (x) { + var program = new UnaryOpProgram(x.shape, ASINH); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.acosh = function (x) { + var program = new UnaryOpProgram(x.shape, ACOSH); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.atanh = function (x) { + var program = new UnaryOpProgram(x.shape, ATANH); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.erf = function (x) { + var program = new UnaryOpProgram(x.shape, ERF); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.step = function (x, alpha) { + var program = new UnaryOpProgram(x.shape, STEP(alpha)); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.conv2dByMatMul = function (x, filter, convInfo, bias, activation, preluActivationWeights) { + // Reshapes conv2D input to 2D tensors, uses matMul and then reshape the + // result from 2D to 4D. + var xShape = x.shape; + var xTexData = this.texData.get(x.dataId); + var sharedMatMulDim = convInfo.inChannels; + var outerShapeX = xShape[0] * xShape[1] * xShape[2]; + var outerShapeFilter = convInfo.outChannels; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + var transposeA = false; + var transposeB = false; + // TODO: Once reduction ops are packed, batchMatMul will always be packed + // and we can remove this condition. + var batchMatMulWillBeUnpacked = (outerShapeX === 1 || outerShapeFilter === 1) && + sharedMatMulDim > MATMUL_SHARED_DIM_THRESHOLD; + var reshapeWillBeExpensive = xShape[2] % 2 !== 0 && !!xTexData.isPacked; + if (batchMatMulWillBeUnpacked || !env().getBool('WEBGL_LAZILY_UNPACK') || + !env().getBool('WEBGL_PACK_BINARY_OPERATIONS') || + !reshapeWillBeExpensive) { + var targetShape_1 = isChannelsLast ? xShape[0] * xShape[1] * xShape[2] : + xShape[0] * xShape[2] * xShape[3]; + var xReshaped_1 = this.reshape(x, [1, targetShape_1, convInfo.inChannels]); + var filterReshaped_1 = this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]); + return this.reshape(this.fusedBatchMatMul({ + a: xReshaped_1, + b: filterReshaped_1, + transposeA: transposeA, + transposeB: transposeB, + bias: bias, + activation: activation, + preluActivationWeights: preluActivationWeights + }), convInfo.outShape); + } + // Following optimization is specific to packed |x| with odd row count + // (For example, in channelLast mode, 'row count' refers to x.shape[2]): + // we avoid expensive packed 2x2 reshape by padding row count to next, + // even number. When x.shape[2] is odd, the result of packed batchMatMul is + // the same (has the same texture layout and and values in the texture) as + // it is for even x.shape[2] + 1. We make the odd-rows tensor to look like + // even-rows tensor before the operation and, after the batchMatMul, + // fix the even-rows result to have odd number of rows. + var targetShape = isChannelsLast ? + xShape[0] * xShape[1] * (xShape[2] + 1) : + xShape[0] * xShape[2] * (xShape[3] + 1); + var xReshaped = { + dataId: x.dataId, + shape: [1, targetShape, convInfo.inChannels], + dtype: x.dtype + }; + // xTexData.shape gets referenced from GPGPUBinary.inShapeInfos. + // Decrementing row count, after batchMatMul->...->compileProgram leads to + // invalid row count within the reference in GPGPUBinary.inShapeInfos. + // Alternative fix would be to provide a copy to GPGPUBinary.inShapeInfos + // in compileProgram method, but that would affect compilation of all + // programs - instead, provide a copy here, with even row count, before + // calling batchMatMul->...->compileProgram and after that, the original + // xTexData.shape is restored. + var originalXTexDataShape = xTexData.shape; + xTexData.shape = xTexData.shape.slice(); + xTexData.shape[xTexData.shape.length - 2]++; + assert(isReshapeFree(xTexData.shape, xReshaped.shape), function () { return "packed reshape " + xTexData.shape + " to " + xReshaped.shape + " isn't free"; }); + var filterReshaped = this.reshape(filter, [1, convInfo.inChannels, convInfo.outChannels]); + var pointwiseConv = this.fusedBatchMatMul({ + a: xReshaped, + b: filterReshaped, + transposeA: transposeA, + transposeB: transposeB, + bias: bias, + activation: activation, + preluActivationWeights: preluActivationWeights + }); + var pointwiseConvTexData = this.texData.get(pointwiseConv.dataId); + assert(pointwiseConvTexData.isPacked, function () { return 'batchMatMul result is expected to be packed'; }); + // Restore the input shape to original. + xTexData.shape = originalXTexDataShape; + // Set the output shape - there is no need for expensive reshape as data + // layout is already correct. + pointwiseConvTexData.shape = convInfo.outShape; + return ENGINE.makeTensorFromDataId(pointwiseConv.dataId, convInfo.outShape, pointwiseConv.dtype); + }; + MathBackendWebGL.prototype.conv2dWithIm2Row = function (x, filter, convInfo, bias, activation, preluActivationWeights) { + // Rearranges conv2d input so each block to be convolved over forms the + // column of a new matrix with shape [filterWidth * filterHeight * + // inChannels, outHeight * outWidth]. The filter is also rearranged so each + // output channel forms a row of a new matrix with shape [outChannels, + // filterWidth * filterHeight * inChannels]. The convolution is then + // computed by multiplying these matrices and reshaping the result. + var filterWidth = convInfo.filterWidth, filterHeight = convInfo.filterHeight, inChannels = convInfo.inChannels, outWidth = convInfo.outWidth, outHeight = convInfo.outHeight, dataFormat = convInfo.dataFormat; + var isChannelsLast = dataFormat === 'channelsLast'; + var sharedDim = filterWidth * filterHeight * inChannels; + var numCols = outHeight * outWidth; + var x2ColShape = [sharedDim, numCols]; + var transposeA = true; + var transposeB = false; + var xSqueezed = x.squeeze([0]); + var w2Row = filter.reshape([1, sharedDim, -1]); + var im2ColProgram = new Im2ColPackedProgram(x2ColShape, xSqueezed.shape, convInfo); + var im2Col = this.compileAndRun(im2ColProgram, [xSqueezed]).reshape([ + 1, x2ColShape[0], x2ColShape[1] + ]); + var hasBias = bias != null; + var hasPreluActivationWeights = preluActivationWeights != null; + var fusedActivation = activation ? mapActivationToShaderProgram(activation, true) : null; + var matmulProgram = new MatMulPackedProgram(im2Col.shape, [1, numCols, convInfo.outChannels], transposeA, transposeB, hasBias, fusedActivation, hasPreluActivationWeights); + var inputs = [im2Col, w2Row]; + if (bias) { + inputs.push(bias); + } + if (hasPreluActivationWeights) { + inputs.push(preluActivationWeights); + } + var product = this.compileAndRun(matmulProgram, inputs); + if (isChannelsLast) { + return product.reshape([1, outHeight, outWidth, convInfo.outChannels]); + } + else { + return product.reshape([1, convInfo.outChannels, outHeight, outWidth]); + } + }; + MathBackendWebGL.prototype.fusedConv2d = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && + convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && + convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && + (convInfo.padInfo.type === 'SAME' || + convInfo.padInfo.type === 'VALID')) { + return this.conv2dByMatMul(input, filter, convInfo, bias, activation, preluActivationWeights); + } + if (env().getBool('WEBGL_CONV_IM2COL') && input.shape[0] === 1) { + return this.conv2dWithIm2Row(input, filter, convInfo, bias, activation, preluActivationWeights); + } + var hasBias = bias != null; + var hasPreluActivationWeights = preluActivationWeights != null; + var fusedActivation = activation ? mapActivationToShaderProgram(activation, false) : null; + var program = new Conv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights); + var inputs = [input, filter]; + if (bias) { + inputs.push(bias); + } + if (preluActivationWeights) { + inputs.push(preluActivationWeights); + } + return this.compileAndRun(program, inputs); + }; + MathBackendWebGL.prototype.conv2d = function (x, filter, convInfo) { + if (convInfo.filterHeight === 1 && convInfo.filterWidth === 1 && + convInfo.dilationHeight === 1 && convInfo.dilationWidth === 1 && + convInfo.strideHeight === 1 && convInfo.strideWidth === 1 && + (convInfo.padInfo.type === 'SAME' || + convInfo.padInfo.type === 'VALID')) { + return this.conv2dByMatMul(x, filter, convInfo); + } + if (env().getBool('WEBGL_CONV_IM2COL') && x.shape[0] === 1) { + return this.conv2dWithIm2Row(x, filter, convInfo); + } + var program = new Conv2DProgram(convInfo); + return this.compileAndRun(program, [x, filter]); + }; + MathBackendWebGL.prototype.conv2dDerInput = function (dy, filter, convInfo) { + var program = new Conv2DDerInputProgram(convInfo); + return this.compileAndRun(program, [dy, filter]); + }; + MathBackendWebGL.prototype.conv2dDerFilter = function (x, dy, convInfo) { + var program = new Conv2DDerFilterProgram(convInfo); + return this.compileAndRun(program, [x, dy]); + }; + MathBackendWebGL.prototype.fusedDepthwiseConv2D = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var shouldPackDepthwiseConv = env().getBool('WEBGL_PACK_DEPTHWISECONV') && + convInfo.strideWidth <= 2 && + convInfo.outChannels / convInfo.inChannels === 1; + var fusedActivation = activation ? + mapActivationToShaderProgram(activation, shouldPackDepthwiseConv) : + null; + var inputs = [input, filter]; + var hasBias = bias != null; + var hasPreluActivationWeights = preluActivationWeights != null; + if (hasBias) { + inputs.push(bias); + } + if (hasPreluActivationWeights) { + inputs.push(preluActivationWeights); + } + var program; + if (shouldPackDepthwiseConv) { + program = new DepthwiseConvPacked2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights); + return this.compileAndRun(program, inputs); + } + program = new DepthwiseConv2DProgram(convInfo, hasBias, fusedActivation, hasPreluActivationWeights); + return this.compileAndRun(program, inputs); + }; + MathBackendWebGL.prototype.depthwiseConv2D = function (x, filter, convInfo) { + var program; + if (env().getBool('WEBGL_PACK_DEPTHWISECONV') && + convInfo.strideWidth <= 2 && + convInfo.outChannels / convInfo.inChannels === 1) { + program = new DepthwiseConvPacked2DProgram(convInfo); + return this.compileAndRun(program, [x, filter]); + } + program = new DepthwiseConv2DProgram(convInfo); + return this.compileAndRun(program, [x, filter]); + }; + MathBackendWebGL.prototype.depthwiseConv2DDerInput = function (dy, filter, convInfo) { + var program = new DepthwiseConv2DDerInputProgram(convInfo); + return this.compileAndRun(program, [dy, filter]); + }; + MathBackendWebGL.prototype.depthwiseConv2DDerFilter = function (x, dy, convInfo) { + var program = new DepthwiseConv2DDerFilterProgram(convInfo); + return this.compileAndRun(program, [x, dy]); + }; + MathBackendWebGL.prototype.conv3d = function (x, filter, convInfo) { + var program = new Conv3DProgram(convInfo); + return this.compileAndRun(program, [x, filter]); + }; + MathBackendWebGL.prototype.conv3dDerInput = function (dy, filter, convInfo) { + var program = new Conv3DDerInputProgram(convInfo); + return this.compileAndRun(program, [dy, filter]); + }; + MathBackendWebGL.prototype.conv3dDerFilter = function (x, dy, convInfo) { + var program = new Conv3DDerFilterProgram(convInfo); + return this.compileAndRun(program, [x, dy]); + }; + MathBackendWebGL.prototype.maxPool = function (x, convInfo) { + var program = new Pool2DProgram(convInfo, 'max', false); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.avgPool = function (x, convInfo) { + var program = new Pool2DProgram(convInfo, 'avg', false); + return this.compileAndRun(program, [x], 'float32'); + }; + MathBackendWebGL.prototype.maxPoolBackprop = function (dy, x, y, convInfo) { + var getPositions = true; + var maxPoolPositionsProgram = new Pool2DProgram(convInfo, 'max', getPositions); + var maxPoolPositions = this.compileAndRun(maxPoolPositionsProgram, [x]); + var maxPoolBackPropProgram = new MaxPool2DBackpropProgram(convInfo); + var result = this.compileAndRun(maxPoolBackPropProgram, [dy, maxPoolPositions], x.dtype); + maxPoolPositions.dispose(); + return result; + }; + MathBackendWebGL.prototype.avgPoolBackprop = function (dy, x, convInfo) { + var avgPoolBackpropProgram = new AvgPool2DBackpropProgram(convInfo); + return this.compileAndRun(avgPoolBackpropProgram, [dy], x.dtype); + }; + MathBackendWebGL.prototype.cast = function (x, dtype) { + return castTensor(x, dtype, this); + }; + MathBackendWebGL.prototype.unstack = function (x, axis) { + var num = x.shape[axis]; + var outShape = new Array(x.rank - 1); + var outIndex = 0; + for (var i = 0; i < x.rank; i++) { + if (i !== axis) { + outShape[outIndex++] = x.shape[i]; + } + } + var begin = new Array(x.rank).fill(0); + var size = x.shape.slice(); + size[axis] = 1; + var res = new Array(num); + for (var i = 0; i < res.length; i++) { + begin[axis] = i; + res[i] = this.slice(x, begin, size).reshape(outShape); + } + return res; + }; + MathBackendWebGL.prototype.avgPool3d = function (x, convInfo) { + var program = new Pool3DProgram(convInfo, 'avg', false); + return this.compileAndRun(program, [x], 'float32'); + }; + MathBackendWebGL.prototype.avgPool3dBackprop = function (dy, x, convInfo) { + var avgPool3dBackpropProgram = new AvgPool3DBackpropProgram(convInfo); + return this.compileAndRun(avgPool3dBackpropProgram, [dy], x.dtype); + }; + MathBackendWebGL.prototype.maxPool3d = function (x, convInfo) { + var program = new Pool3DProgram(convInfo, 'max', false); + return this.compileAndRun(program, [x], 'float32'); + }; + MathBackendWebGL.prototype.maxPool3dBackprop = function (dy, x, y, convInfo) { + var getPositions = true; + var maxPool3dPositionsProgram = new Pool3DProgram(convInfo, 'max', getPositions); + var maxPool3dPositions = this.compileAndRun(maxPool3dPositionsProgram, [x]); + var maxPool3dBackPropProgram = new MaxPool3DBackpropProgram(convInfo); + var result = this.compileAndRun(maxPool3dBackPropProgram, [dy, maxPool3dPositions], x.dtype); + maxPool3dPositions.dispose(); + return result; + }; + MathBackendWebGL.prototype.reshape = function (x, shape) { + var texData = this.texData.get(x.dataId); + if (texData.isPacked && !isReshapeFree(x.shape, shape) && + !(texData.texture !== null && + isReshapeFree(texData.shape, shape))) { + var info = this.packedReshape(x, shape); + return ENGINE.makeTensorFromDataId(info.dataId, info.shape, info.dtype); + } + return reshapeTensor(x, shape); + }; + MathBackendWebGL.prototype.resizeBilinear = function (x, newHeight, newWidth, alignCorners) { + var program = env().getBool('WEBGL_PACK_IMAGE_OPERATIONS') ? + new ResizeBilinearPackedProgram(x.shape, newHeight, newWidth, alignCorners) : + new ResizeBilinearProgram(x.shape, newHeight, newWidth, alignCorners); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.resizeBilinearBackprop = function (dy, x, alignCorners) { + var program = new ResizeBilinearBackpropProgram(dy, x, alignCorners); + return this.compileAndRun(program, [dy]); + }; + MathBackendWebGL.prototype.resizeNearestNeighbor = function (x, newHeight, newWidth, alignCorners) { + var program = new ResizeNearestNeighborProgram(x.shape, newHeight, newWidth, alignCorners); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.resizeNearestNeighborBackprop = function (dy, x, alignCorners) { + var program = new ResizeNearestNeigborBackpropProgram(dy, x, alignCorners); + return this.compileAndRun(program, [dy]); + }; + MathBackendWebGL.prototype.multinomial = function (logits, normalized, numSamples, seed) { + var probs = normalized ? logits : softmax(logits); + var batchSize = probs.shape[0]; + var numOutcomes = probs.shape[1]; + var program = new MultinomialProgram(batchSize, numOutcomes, numSamples); + var customSetup = program.getCustomSetupFunc(seed); + return this.compileAndRun(program, [probs], 'int32', customSetup); + }; + MathBackendWebGL.prototype.oneHot = function (indices, depth, onValue, offValue) { + var program = new OneHotProgram(indices.size, depth, onValue, offValue); + return this.compileAndRun(program, [indices]); + }; + MathBackendWebGL.prototype.diag = function (x) { + var program = new DiagProgram(x.size); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.nonMaxSuppression = function (boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + warn('tf.nonMaxSuppression() in webgl locks the UI thread. ' + + 'Call tf.nonMaxSuppressionAsync() instead'); + var boxesVals = boxes.dataSync(); + var scoresVals = scores.dataSync(); + return nonMaxSuppressionImpl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + }; + MathBackendWebGL.prototype.cropAndResize = function (image, boxes, boxIndex, cropSize, method, extrapolationValue) { + var program = new CropAndResizeProgram(image.shape, boxes.shape, cropSize, method, extrapolationValue); + return this.compileAndRun(program, [image, boxes, boxIndex]); + }; + MathBackendWebGL.prototype.depthToSpace = function (x, blockSize, dataFormat) { + assert(blockSize > 1, function () { + return "blockSize should be > 1 for depthToSpace, but was: " + blockSize; + }); + var batchSize = x.shape[0]; + var inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; + var inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; + var inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; + var outputHeight = inputHeight * blockSize; + var outputWidth = inputWidth * blockSize; + var outputDepth = inputDepth / (blockSize * blockSize); + var outputShape = (dataFormat === 'NHWC') ? + [batchSize, outputHeight, outputWidth, outputDepth] : + [batchSize, outputDepth, outputHeight, outputWidth]; + var program = new DepthToSpaceProgram(outputShape, blockSize, dataFormat); + return this.compileAndRun(program, [x]); + }; + MathBackendWebGL.prototype.split = function (x, sizeSplits, axis) { + return split$1(x, sizeSplits, axis); + }; + MathBackendWebGL.prototype.scatterND = function (indices, updates, shape) { + var _a = calculateShapes(updates, indices, shape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize; + var flattenShape = [outputSize / sliceSize, sliceSize]; + var flattenIndices = indices.reshape([numUpdates, sliceRank]); + var flattenX = updates.reshape([numUpdates, sliceSize]); + if (outputSize === 0) { + return reshapeTensor(tensor([]), shape); + } + var defaultValue = scalar(0); + var program = new ScatterProgram(numUpdates, sliceRank, flattenIndices.rank, flattenX.rank, strides, flattenShape); + var res = this.compileAndRun(program, [flattenX, flattenIndices, defaultValue]); + return res.reshape(shape); + }; + MathBackendWebGL.prototype.sparseToDense = function (sparseIndices, sparseValues, outputShape, defaultValue) { + var _a = calculateShapes(sparseValues, sparseIndices, outputShape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, strides = _a.strides, outputSize = _a.outputSize; + var sumDupeIndices = false; + var program = new ScatterProgram(numUpdates, sliceRank, sparseIndices.rank, sparseValues.rank, strides, [outputSize, 1], sumDupeIndices); + var res = this.compileAndRun(program, [sparseValues, sparseIndices, defaultValue]); + return res.reshape(outputShape); + }; + MathBackendWebGL.prototype.fft = function (x) { + var inverse = false; + return this.fftImpl(x, inverse); + }; + MathBackendWebGL.prototype.ifft = function (x) { + var inverse = true; + return this.fftImpl(x, inverse); + }; + MathBackendWebGL.prototype.fftImpl = function (x, inverse) { + var xData = this.texData.get(x.dataId); + var realProgram = new FFTProgram(COMPLEX_FFT.REAL, x.shape, inverse); + var imagProgram = new FFTProgram(COMPLEX_FFT.IMAG, x.shape, inverse); + var inputs = [ + this.makeComplexComponentTensorInfo(x, xData.complexTensors.real), + this.makeComplexComponentTensorInfo(x, xData.complexTensors.imag), + ]; + var real = this.compileAndRun(realProgram, inputs); + var imag = this.compileAndRun(imagProgram, inputs); + var complex = this.complex(real, imag).as2D(x.shape[0], x.shape[1]); + real.dispose(); + imag.dispose(); + return complex; + }; + MathBackendWebGL.prototype.gatherND = function (x, indices) { + var indicesShape = indices.shape; + var sliceRank = indicesShape[indicesShape.length - 1]; + var _a = prepareAndValidate(x, indices), resultShape = _a[0], numSlices = _a[1], sliceSize = _a[2], strides = _a[3]; + var flattenIndices = indices.reshape([numSlices, sliceRank]); + var flattenX = x.reshape([x.size / sliceSize, sliceSize]); + var program = new GatherNDProgram(sliceRank, strides, [numSlices, sliceSize]); + var res = this.compileAndRun(program, [flattenX, flattenIndices]); + return res.reshape(resultShape); + }; + MathBackendWebGL.prototype.fill = function (shape, value, dtype) { + dtype = dtype || inferDtype(value); + if (dtype === 'string') { + // String type should be handled in CPU memory. + var values = getArrayFromDType(dtype, sizeFromShape(shape)); + values.fill(value); + return ENGINE.makeTensor(values, shape, dtype, this); + } + else { + var program = new FillProgram(shape, value); + var customSetup = program.getCustomSetupFunc(value); + return this.compileAndRun(program, [], dtype, customSetup); + } + }; + MathBackendWebGL.prototype.onesLike = function (x) { + if (x.dtype === 'string') { + throw new Error('onesLike is not supported under string dtype'); + } + else { + // TODO(cais, smilkov): Add WebGL shader for onesLike: + // https://github.com/tensorflow/tfjs/issues/1293 + return this.fill(x.shape, 1, x.dtype); + } + }; + MathBackendWebGL.prototype.zerosLike = function (x) { + return this.fill(x.shape, x.dtype === 'string' ? '' : 0, x.dtype); + }; + MathBackendWebGL.prototype.linspace = function (start, stop, num) { + // TODO: Use CPU implementation due to the precision problem in Safari. + return linspaceImpl(start, stop, num); + }; + MathBackendWebGL.prototype.makeTensorInfo = function (shape, dtype) { + var dataId = this.write(null /* values */, shape, dtype); + this.texData.get(dataId).usage = null; + return { dataId: dataId, shape: shape, dtype: dtype }; + }; + MathBackendWebGL.prototype.makeOutput = function (shape, dtype) { + var dataId = this.makeTensorInfo(shape, dtype).dataId; + return ENGINE.makeTensorFromDataId(dataId, shape, dtype, this); + }; + MathBackendWebGL.prototype.unpackTensor = function (input) { + var program = new UnpackProgram(input.shape); + return this.runWebGLProgram(program, [input], input.dtype); + }; + MathBackendWebGL.prototype.packTensor = function (input) { + var program = new PackProgram(input.shape); + var preventEagerUnpackingOutput = true; + return this.runWebGLProgram(program, [input], input.dtype, null /* customSetup */, preventEagerUnpackingOutput); + }; + MathBackendWebGL.prototype.packedReshape = function (input, afterShape) { + var input3DShape = [ + getBatchDim(input.shape) + ].concat(getRowsCols(input.shape)); + var input3D = { + dtype: input.dtype, + shape: input3DShape, + dataId: input.dataId + }; + var afterShapeAs3D = [ + getBatchDim(afterShape) + ].concat(getRowsCols(afterShape)); + var program = new ReshapePackedProgram(afterShapeAs3D, input3DShape); + var preventEagerUnpackingOfOutput = true; + var output = this.runWebGLProgram(program, [input3D], input.dtype, null /* customSetup */, preventEagerUnpackingOfOutput); + return { dataId: output.dataId, shape: afterShape, dtype: output.dtype }; + }; + MathBackendWebGL.prototype.decode = function (dataId) { + var texData = this.texData.get(dataId); + var isPacked = texData.isPacked, shape = texData.shape, dtype = texData.dtype; + var shapeAs3D = getShapeAs3D(shape); + var program; + if (isPacked) { + program = new DecodeMatrixPackedProgram(shapeAs3D); + } + else { + program = new DecodeMatrixProgram(shapeAs3D); + } + var preventEagerUnpackingOfOutput = true; + var out = this.runWebGLProgram(program, [{ shape: shapeAs3D, dtype: dtype, dataId: dataId }], dtype, null /* customSetup */, preventEagerUnpackingOfOutput); + return { dtype: dtype, shape: shape, dataId: out.dataId }; + }; + MathBackendWebGL.prototype.runWebGLProgram = function (program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput) { + var _this = this; + if (preventEagerUnpackingOfOutput === void 0) { preventEagerUnpackingOfOutput = false; } + var output = this.makeTensorInfo(program.outputShape, outputDtype); + var outData = this.texData.get(output.dataId); + if (program.packedOutput) { + outData.isPacked = true; + } + if (program.outPackingScheme === PackingScheme.DENSE) { + var texelShape = getDenseTexShape(program.outputShape); + // For a densely packed output, we explicitly set texShape + // so it doesn't get assigned later according to our typical packing + // scheme wherein a single texel can only contain values from adjacent + // rows/cols. + outData.texShape = texelShape.map(function (d) { return d * 2; }); + } + if (program.outTexUsage != null) { + outData.usage = program.outTexUsage; + } + if (sizeFromShape(output.shape) === 0) { + // Short-circuit the computation since the result is empty (has 0 in its + // shape). + outData.values = getTypedArrayFromDType(output.dtype, 0); + return output; + } + var dataToDispose = []; + var inputsData = inputs.map(function (input) { + if (input.dtype === 'complex64') { + throw new Error("GPGPUProgram does not support complex64 input. For complex64 " + + "dtypes, please separate the program into real and imaginary " + + "parts."); + } + var texData = _this.texData.get(input.dataId); + if (texData.texture == null) { + if (!program.packedInputs && + sizeFromShape(input.shape) <= + env().getNumber('WEBGL_SIZE_UPLOAD_UNIFORM')) { + // Upload small tensors that live on the CPU as uniforms, not as + // textures. Do this only when the environment supports 32bit floats + // due to problems when comparing 16bit floats with 32bit floats. + // TODO(https://github.com/tensorflow/tfjs/issues/821): Make it + // possible for packed shaders to sample from uniforms. + return { + shape: input.shape, + texData: null, + isUniform: true, + uniformValues: texData.values + }; + } + // This ensures that if a packed program's inputs have not yet been + // uploaded to the GPU, they get uploaded as packed right off the bat. + if (program.packedInputs) { + texData.isPacked = true; + texData.shape = input.shape; + } + } + else if (!!texData.isPacked !== !!program.packedInputs) { + input = texData.isPacked ? _this.unpackTensor(input) : + _this.packTensor(input); + dataToDispose.push(input); + texData = _this.texData.get(input.dataId); + } + else if (texData.isPacked && + !isReshapeFree(texData.shape, input.shape)) { + // This is a special case where a texture exists for a tensor + // but the shapes are incompatible (due to packing constraints) because + // the tensor did not have a chance to go through the packed reshape + // shader. This only happens when we reshape the *same* tensor to form + // *distinct* inputs to an op, e.g. dotting a vector with itself. This + // case will disappear once packed uploading is the default. + var savedInput = input; + var targetShape = input.shape; + input.shape = texData.shape; + input = _this.packedReshape(input, targetShape); + dataToDispose.push(input); + texData = _this.texData.get(input.dataId); + savedInput.shape = targetShape; + } + _this.uploadToGPU(input.dataId); + return { shape: input.shape, texData: texData, isUniform: false }; + }); + this.uploadToGPU(output.dataId); + var outputData = { shape: output.shape, texData: outData, isUniform: false }; + var key = makeShaderKey(program, inputsData, outputData); + var binary = this.getAndSaveBinary(key, function () { + return compileProgram(_this.gpgpu, program, inputsData, outputData); + }); + var shouldTimeProgram = this.activeTimers != null; + var query; + if (shouldTimeProgram) { + query = this.startTimer(); + } + runProgram(this.gpgpu, binary, inputsData, outputData, customSetup); + dataToDispose.forEach(function (info) { return _this.disposeData(info.dataId); }); + if (shouldTimeProgram) { + query = this.endTimer(query); + this.activeTimers.push({ name: program.constructor.name, query: this.getQueryTime(query) }); + } + if (!env().getBool('WEBGL_LAZILY_UNPACK') && outData.isPacked && + preventEagerUnpackingOfOutput === false) { + var unpacked = this.unpackTensor(output); + this.disposeData(output.dataId); + return unpacked; + } + return output; + }; + MathBackendWebGL.prototype.compileAndRun = function (program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput) { + if (preventEagerUnpackingOfOutput === void 0) { preventEagerUnpackingOfOutput = false; } + outputDtype = outputDtype || inputs[0].dtype; + var outInfo = this.runWebGLProgram(program, inputs, outputDtype, customSetup, preventEagerUnpackingOfOutput); + return ENGINE.makeTensorFromDataId(outInfo.dataId, outInfo.shape, outInfo.dtype); + }; + MathBackendWebGL.prototype.getAndSaveBinary = function (key, getBinary) { + if (!(key in this.binaryCache)) { + this.binaryCache[key] = getBinary(); + } + return this.binaryCache[key]; + }; + MathBackendWebGL.prototype.getTextureManager = function () { + return this.textureManager; + }; + MathBackendWebGL.prototype.dispose = function () { + if (this.disposed) { + return; + } + this.textureManager.dispose(); + if (this.canvas != null && + (typeof (HTMLCanvasElement) !== 'undefined' && + this.canvas instanceof HTMLCanvasElement)) { + this.canvas.remove(); + } + else { + this.canvas = null; + } + if (this.fromPixels2DContext != null && + //@ts-ignore + this.fromPixels2DContext.canvas.remove) { + //@ts-ignore + this.fromPixels2DContext.canvas.remove(); + } + if (this.gpgpuCreatedLocally) { + this.gpgpu.program = null; + this.gpgpu.dispose(); + } + this.disposed = true; + }; + MathBackendWebGL.prototype.floatPrecision = function () { + var _this = this; + if (this.floatPrecisionValue == null) { + this.floatPrecisionValue = tidy(function () { + if (!env().get('WEBGL_RENDER_FLOAT32_ENABLED')) { + // Momentarily switching DEBUG flag to false so we don't throw an + // error trying to upload a small value. + var debugFlag = env().getBool('DEBUG'); + env().set('DEBUG', false); + var underflowCheckValue = _this.abs(scalar(1e-8)).dataSync()[0]; + env().set('DEBUG', debugFlag); + if (underflowCheckValue > 0) { + return 32; + } + } + return 16; + }); + } + return this.floatPrecisionValue; + }; + /** Returns the smallest representable number. */ + MathBackendWebGL.prototype.epsilon = function () { + return this.floatPrecision() === 32 ? EPSILON_FLOAT32 : EPSILON_FLOAT16; + }; + MathBackendWebGL.prototype.uploadToGPU = function (dataId) { + var _a; + var texData = this.texData.get(dataId); + var shape = texData.shape, dtype = texData.dtype, values = texData.values, texture = texData.texture, usage = texData.usage, isPacked = texData.isPacked; + if (texture != null) { + // Array is already on GPU. No-op. + return; + } + var shouldTimeProgram = this.activeTimers != null; + var start; + if (shouldTimeProgram) { + start = now(); + } + var texShape = texData.texShape; + if (texShape == null) { + texShape = getTextureShapeFromLogicalShape(shape, isPacked); + texData.texShape = texShape; + } + if (values != null) { + var shapeAs3D = getShapeAs3D(shape); + var program = void 0; + var width = texShape[1], height = texShape[0]; + var isByteArray = values instanceof Uint8Array; + if (isPacked) { + _a = getPackedMatrixTextureShapeWidthHeight(texShape[0], texShape[1]), width = _a[0], height = _a[1]; + program = new EncodeMatrixPackedProgram(shapeAs3D, [height, width], isByteArray); + } + else { + program = + new EncodeMatrixProgram(shapeAs3D, [height, width], isByteArray); + } + var tempDenseInputHandle = this.makeTensorInfo([height, width], dtype); + if (isByteArray) { + this.texData.get(tempDenseInputHandle.dataId).usage = + TextureUsage.PIXELS; + } + else { + this.texData.get(tempDenseInputHandle.dataId).usage = + TextureUsage.UPLOAD; + } + this.gpgpu.uploadDenseMatrixToTexture(this.getTexture(tempDenseInputHandle.dataId), width, height, values); + // We want the output to remain packed regardless of the value of + // WEBGL_PACK. + var preventEagerUnpacking = true; + var encodedOutputTarget = this.runWebGLProgram(program, [tempDenseInputHandle], dtype, null, preventEagerUnpacking); + // Have the original texture assume the identity of the encoded output. + var outputTexData = this.texData.get(encodedOutputTarget.dataId); + texData.texture = outputTexData.texture; + texData.texShape = outputTexData.texShape; + texData.isPacked = outputTexData.isPacked; + texData.usage = outputTexData.usage; + this.disposeData(tempDenseInputHandle.dataId); + this.texData.delete(encodedOutputTarget.dataId); + // Once uploaded, don't store the values on cpu. + texData.values = null; + if (shouldTimeProgram) { + this.uploadWaitMs += now() - start; + } + } + else { + var newTexture = this.acquireTexture(texShape, usage, dtype, isPacked); + texData.texture = newTexture; + } + }; + MathBackendWebGL.prototype.convertAndCacheOnCPU = function (dataId, float32Values) { + var texData = this.texData.get(dataId); + var dtype = texData.dtype; + this.releaseGPUData(dataId); + if (float32Values != null) { + texData.values = float32ToTypedArray(float32Values, dtype); + } + return texData.values; + }; + MathBackendWebGL.prototype.acquireTexture = function (texShape, texType, dtype, isPacked) { + this.numBytesInGPU += this.computeBytes(texShape, dtype); + if (!this.warnedAboutMemory && + this.numBytesInGPU > this.numMBBeforeWarning * 1024 * 1024) { + var mb = (this.numBytesInGPU / 1024 / 1024).toFixed(2); + this.warnedAboutMemory = true; + console.warn("High memory usage in GPU: " + mb + " MB, " + + "most likely due to a memory leak"); + } + return this.textureManager.acquireTexture(texShape, texType, isPacked); + }; + MathBackendWebGL.prototype.computeBytes = function (shape, dtype) { + return shape[0] * shape[1] * bytesPerElement(dtype); + }; + return MathBackendWebGL; + }(KernelBackend)); + if (isBrowser()) { + ENGINE.registerBackend('webgl', function () { return new MathBackendWebGL(); }, 2 /* priority */); + } + function float32ToTypedArray(a, dtype) { + if (dtype === 'float32' || dtype === 'complex64') { + return a; + } + else if (dtype === 'int32' || dtype === 'bool') { + var result = (dtype === 'int32') ? new Int32Array(a.length) : + new Uint8Array(a.length); + for (var i = 0; i < result.length; ++i) { + result[i] = Math.round(a[i]); + } + return result; + } + else { + throw new Error("Unknown dtype " + dtype); + } + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes square of `x` element-wise: `x ^ 2` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.sqrt(2), -1]); + * + * x.square().print(); // or tf.square(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function square_(x) { + var $x = convertToTensor(x, 'x', 'square'); + var grad = function (dy, saved) { + var x = saved[0]; + return { x: function () { return dy.mul(x.toFloat().mul(2)); } }; + }; + var kernelName = 'Square'; + var attrs = {}; + var inputsToSave = [$x]; + var outputsToSave = []; + return ENGINE.runKernelFunc(function (backend, save) { + save([$x]); + return backend.square($x); + }, { x: $x }, grad, kernelName, attrs, inputsToSave, outputsToSave); + } + var square = op({ square_: square_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes `-1 * x` element-wise. + * + * ```js + * const x = tf.tensor2d([1, 2, -2, 0], [2, 2]); + * + * x.neg().print(); // or tf.neg(x) + * ``` + * + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function neg_(x) { + var $x = convertToTensor(x, 'x', 'neg'); + var grad = function (dy) { + return { $x: function () { return dy.neg(); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.neg($x); }, { $x: $x }, grad); + } + /** + * Computes ceiling of input `tf.Tensor` element-wise: `ceil(x)` + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.ceil().print(); // or tf.ceil(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function ceil_(x) { + var $x = convertToTensor(x, 'x', 'ceil'); + // TODO(manrajgrover): Return null for gradients when backprop supports it. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.ceil($x); }, { $x: $x }, grad); + } + /** + * Computes floor of input `tf.Tensor` element-wise: `floor(x)`. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.floor().print(); // or tf.floor(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function floor_(x) { + var $x = convertToTensor(x, 'x', 'floor'); + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropgation. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.floor($x); }, { $x: $x }, grad); + } + /** + * Returns an element-wise indication of the sign of a number. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3, NaN, 0]); + * + * x.sign().print(); // or tf.sign(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sign_(x) { + var $x = convertToTensor(x, 'x', 'sign'); + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.sign($x); }, { $x: $x }, grad); + } + /** + * RReturns which elements of x are NaN. + * + * ```js + * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + * + * x.isNaN().print(); // or tf.isNaN(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function isNaN_(x) { + var $x = convertToTensor(x, 'x', 'isNaN'); + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropgation. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.isNaN($x); }, { $x: $x }, grad); + } + /** + * Returns which elements of x are Infinity or -Infinity. + * + * ```js + * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + * + * x.isInf().print(); // or tf.isNaN(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function isInf_(x) { + var $x = convertToTensor(x, 'x', 'isInf'); + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropgation. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.isInf($x); }, { $x: $x }, grad); + } + /** + * Returns which elements of x are finite. + * + * ```js + * const x = tf.tensor1d([NaN, Infinity, -Infinity, 0, 1]); + * + * x.isFinite().print(); // or tf.isNaN(x) + * ``` + * @param x The input Tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function isFinite_(x) { + var $x = convertToTensor(x, 'x', 'isFinite'); + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropgation. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.isFinite($x); }, { $x: $x }, grad); + } + /** + * Computes round of input `tf.Tensor` element-wise: `round(x)`. + * It implements banker's rounding. + * + * ```js + * const x = tf.tensor1d([.6, 1.1, -3.3]); + * + * x.round().print(); // or tf.round(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function round_(x) { + var $x = convertToTensor(x, 'x', 'round'); + // TODO(nsthorat): Let gradients be null for cases where we want to stop + // backpropgation. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.round($x); }, { $x: $x }, grad); + } + /** + * Computes exponential of the input `tf.Tensor` element-wise. `e ^ x` + * + * ```js + * const x = tf.tensor1d([1, 2, -3]); + * + * x.exp().print(); // or tf.exp(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function exp_(x) { + var $x = convertToTensor(x, 'x', 'exp'); + var bck = function (dy, saved) { + return { $x: function () { return dy.mulStrict(saved[0]); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.exp($x); + save([y]); + return y; + }, { $x: $x }, bck); + } + /** + * Computes exponential of the input `tf.Tensor` minus one element-wise. + * `e ^ x - 1` + * + * ```js + * const x = tf.tensor1d([1, 2, -3]); + * + * x.expm1().print(); // or tf.expm1(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function expm1_(x) { + var $x = convertToTensor(x, 'x', 'expm1'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.mul($x.exp()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.expm1($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes natural logarithm of the input `tf.Tensor` element-wise: `ln(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.E]); + * + * x.log().print(); // or tf.log(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function log_(x) { + var $x = convertToTensor(x, 'x', 'log'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.toFloat()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.log($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes natural logarithm of the input `tf.Tensor` plus one + * element-wise: `ln(1 + x)` + * + * ```js + * const x = tf.tensor1d([1, 2, Math.E - 1]); + * + * x.log1p().print(); // or tf.log1p(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function log1p_(x) { + var $x = convertToTensor(x, 'x', 'log1p'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.add(1)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.log1p($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes square root of the input `tf.Tensor` element-wise: `y = sqrt(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, 4, -1]); + * + * x.sqrt().print(); // or tf.sqrt(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sqrt_(x) { + var $x = convertToTensor(x, 'x', 'sqrt'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.toFloat().sqrt().mul(2)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.sqrt($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes reciprocal of square root of the input `tf.Tensor` element-wise: + * `y = 1 / sqrt(x)` + * + * ```js + * const x = tf.tensor1d([1, 2, 4, -1]); + * + * x.rsqrt().print(); // or tf.rsqrt(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function rsqrt_(x) { + var $x = convertToTensor(x, 'x', 'rsqrt'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.pow(1.5).mul(2)).neg(); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.rsqrt($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes reciprocal of x element-wise: `1 / x` + * + * ```js + * const x = tf.tensor1d([0, 1, 2]); + * + * x.reciprocal().print(); // or tf.reciprocal(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function reciprocal_(x) { + var $x = convertToTensor(x, 'x', 'reciprocal'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.square().neg()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.reciprocal($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes absolute value element-wise: `abs(x)` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.abs().print(); // or tf.abs(x) + * ``` + * @param x The input `tf.Tensor`. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function abs_(x) { + var $x = convertToTensor(x, 'x', 'abs'); + if ($x.dtype === 'complex64') { + return ENGINE.runKernelFunc(function (backend) { return backend.complexAbs($x); }, { $x: $x }); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { x: function () { return dy.mul($x.toFloat().step(-1)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.abs($x); + save([$x]); + return res; + }, { x: $x }, grad, 'Abs'); + } + /** + * Clips values element-wise. `max(min(x, clipValueMax), clipValueMin)` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.clipByValue(-2, 3).print(); // or tf.clipByValue(x, -2, 3) + * ``` + * @param x The input tensor. + * @param clipValueMin Lower-bound of range to be clipped to. + * @param clipValueMax Upper-bound of range to be clipped to. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function clipByValue_(x, clipValueMin, clipValueMax) { + var $x = convertToTensor(x, 'x', 'clipByValue'); + assert((clipValueMin <= clipValueMax), function () { return "Error in clip: min (" + clipValueMin + ") must be " + + ("less than or equal to max (" + clipValueMax + ")."); }); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + x: function () { return dy.where($x.greaterEqual(clipValueMin) + .logicalAnd($x.lessEqual(clipValueMax)), zerosLike(dy)); }, + }; + }; + var inputsToSave = [$x]; + var attr = { min: clipValueMin, max: clipValueMax }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.clip($x, clipValueMin, clipValueMax); + save([$x]); + return res; + }, { x: $x }, grad, 'ClipByValue', attr, inputsToSave); + } + /** + * Computes sigmoid element-wise, `1 / (1 + exp(-x))` + * + * ```js + * const x = tf.tensor1d([0, -1, 2, -3]); + * + * x.sigmoid().print(); // or tf.sigmoid(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sigmoid_(x) { + var $x = convertToTensor(x, 'x', 'sigmoid'); + var grad = function (dy, saved) { + var y = saved[0]; + return { x: function () { return dy.mul(y.mul(scalar(1).sub(y))); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.sigmoid($x); + save([y]); + return y; + }, { x: $x }, grad, 'Sigmoid'); + } + /** + * Computes log sigmoid of the input `tf.Tensor` element-wise: + * `logSigmoid(x)`. For numerical stability, we use `-tf.softplus(-x)`. + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.logSigmoid().print(); // or tf.logSigmoid(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function logSigmoid_(x) { + var $x = convertToTensor(x, 'x', 'logSigmoid'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.mul($x.neg().sigmoid()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.softplus($x.neg()).neg(); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes softplus of the input `tf.Tensor` element-wise: `log(exp(x) + 1)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.softplus().print(); // or tf.softplus(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function softplus_(x) { + var $x = convertToTensor(x, 'x', 'softplus'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.mul($x.sigmoid()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.softplus($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes sin of the input Tensor element-wise: `sin(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.sin().print(); // or tf.sin(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sin_(x) { + var $x = convertToTensor(x, 'x', 'sin'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return $x.toFloat().cos().mul(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.sin($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes cos of the input `tf.Tensor` element-wise: `cos(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.cos().print(); // or tf.cos(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function cos_(x) { + var $x = convertToTensor(x, 'x', 'cos'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return $x.toFloat().sin().neg().mul(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.cos($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes tan of the input `tf.Tensor` element-wise, `tan(x)` + * + * ```js + * const x = tf.tensor1d([0, Math.PI / 2, Math.PI * 3 / 4]); + * + * x.tan().print(); // or tf.tan(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function tan_(x) { + var $x = convertToTensor(x, 'x', 'tan'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.cos().square()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.tan($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes asin of the input `tf.Tensor` element-wise: `asin(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.asin().print(); // or tf.asin(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function asin_(x) { + var $x = convertToTensor(x, 'x', 'asin'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { return dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt()); } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.asin($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes acos of the input `tf.Tensor` element-wise: `acos(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.acos().print(); // or tf.acos(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function acos_(x) { + var $x = convertToTensor(x, 'x', 'acos'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { + return dy.divStrict(scalar(1).sub($x.toFloat().square()).sqrt()).neg(); + } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.acos($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes atan of the input `tf.Tensor` element-wise: `atan(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.atan().print(); // or tf.atan(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function atan_(x) { + var $x = convertToTensor(x, 'x', 'atan'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div($x.toFloat().square().add(1)); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.atan($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes hyperbolic sin of the input `tf.Tensor` element-wise: `sinh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.sinh().print(); // or tf.sinh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function sinh_(x) { + var $x = convertToTensor(x, 'x', 'sinh'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return $x.toFloat().cosh().mulStrict(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.sinh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes hyperbolic cos of the input `tf.Tensor` element-wise: `cosh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.cosh().print(); // or tf.cosh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function cosh_(x) { + var $x = convertToTensor(x, 'x', 'cosh'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return $x.toFloat().sinh().mulStrict(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.cosh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes hyperbolic tangent of the input `tf.Tensor` element-wise: `tanh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, 70]); + * + * x.tanh().print(); // or tf.tanh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function tanh_(x) { + var $x = convertToTensor(x, 'x', 'tanh'); + var grad = function (dy, saved) { + var y = saved[0]; + return { $x: function () { return scalar(1).sub(y.square()).mulStrict(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.tanh($x); + save([y]); + return y; + }, { $x: $x }, grad); + } + /** + * Computes inverse hyperbolic sin of the input `tf.Tensor` element-wise: + * `asinh(x)` + * + * ```js + * const x = tf.tensor1d([0, 1, -1, .7]); + * + * x.asinh().print(); // or tf.asinh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function asinh_(x) { + var $x = convertToTensor(x, 'x', 'asinh'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { return dy.divStrict(scalar(1).add($x.toFloat().square()).sqrt()); } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.asinh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes the inverse hyperbolic cos of the input `tf.Tensor` element-wise: + * `acosh(x)` + * + * ```js + * const x = tf.tensor1d([10, 1, 3, 5.7]); + * + * x.acosh().print(); // or tf.acosh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function acosh_(x) { + var $x = convertToTensor(x, 'x', 'acosh'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.divStrict($x.toFloat().square().sub(1).sqrt()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.acosh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes inverse hyperbolic tan of the input `tf.Tensor` element-wise: + * `atanh(x)` + * + * ```js + * const x = tf.tensor1d([0, .1, -.1, .7]); + * + * x.atanh().print(); // or tf.atanh(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function atanh_(x) { + var $x = convertToTensor(x, 'x', 'atanh'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.div(scalar(1).sub($x.toFloat().square())); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.atanh($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes gause error function of the input `tf.Tensor` element-wise: + * `erf(x)` + * + * ```js + * const x = tf.tensor1d([0, .1, -.1, .7]); + * + * x.erf().print(); // or tf.erf(x); + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function erf_(x) { + var $x = convertToTensor(x, 'x', 'erf'); + assert($x.dtype === 'int32' || $x.dtype === 'float32', function () { return 'Input dtype must be `int32` or `float32`.'; }); + if ($x.dtype === 'int32') { + $x = $x.toFloat(); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { return dy.mul($x.square().neg().exp().mul(2 / Math.sqrt(Math.PI))); } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.erf($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes step of the input `tf.Tensor` element-wise: `x > 0 ? 1 : alpha * x` + * + * ```js + * const x = tf.tensor1d([0, 2, -1, -3]); + * + * x.step(.5).print(); // or tf.step(x, .5) + * ``` + * @param x The input tensor. + * @param alpha The gradient when input is negative. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function step_(x, alpha) { + if (alpha === void 0) { alpha = 0.0; } + var $x = convertToTensor(x, 'x', 'step'); + // TODO(manrajgrover): Return null for gradients when backprop supports + // it. + var grad = function (dy) { + return { $x: function () { return zerosLike(dy); } }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.step($x, alpha); }, { $x: $x }, grad); + } + var abs = op({ abs_: abs_ }); + var acos = op({ acos_: acos_ }); + var acosh = op({ acosh_: acosh_ }); + var asin = op({ asin_: asin_ }); + var asinh = op({ asinh_: asinh_ }); + var atan = op({ atan_: atan_ }); + var atanh = op({ atanh_: atanh_ }); + var ceil = op({ ceil_: ceil_ }); + var clipByValue = op({ clipByValue_: clipByValue_ }); + var cos = op({ cos_: cos_ }); + var cosh = op({ cosh_: cosh_ }); + var erf = op({ erf_: erf_ }); + var exp = op({ exp_: exp_ }); + var expm1 = op({ expm1_: expm1_ }); + var floor = op({ floor_: floor_ }); + var log = op({ log_: log_ }); + var log1p = op({ log1p_: log1p_ }); + var logSigmoid = op({ logSigmoid_: logSigmoid_ }); + var neg = op({ neg_: neg_ }); + var reciprocal = op({ reciprocal_: reciprocal_ }); + var round = op({ round_: round_ }); + var rsqrt = op({ rsqrt_: rsqrt_ }); + var sigmoid = op({ sigmoid_: sigmoid_ }); + var sign = op({ sign_: sign_ }); + var isNaN$1 = op({ isNaN_: isNaN_ }); + var isInf = op({ isInf_: isInf_ }); + var isFinite$1 = op({ isFinite_: isFinite_ }); + var sin = op({ sin_: sin_ }); + var sinh = op({ sinh_: sinh_ }); + var softplus = op({ softplus_: softplus_ }); + var sqrt = op({ sqrt_: sqrt_ }); + var step = op({ step_: step_ }); + var tan = op({ tan_: tan_ }); + var tanh$1 = op({ tanh_: tanh_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Batch normalization, strictly for 2D. For the more relaxed version, see + * `tf.batchNorm`. + * + * @param x The input Tensor. + * @param mean A mean Tensor. + * @param variance A variance Tensor. + * @param offset An offset Tensor. + * @param scale A scale Tensor. + * @param varianceEpsilon A small float number to avoid dividing by 0. + */ + function batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon) { + var $x = convertToTensor(x, 'x', 'batchNorm'); + var $mean = convertToTensor(mean, 'mean', 'batchNorm'); + var $variance = convertToTensor(variance, 'variance', 'batchNorm'); + var $scale; + if (scale != null) { + $scale = convertToTensor(scale, 'scale', 'batchNorm'); + } + var $offset; + if (offset != null) { + $offset = convertToTensor(offset, 'offset', 'batchNorm'); + } + assert($x.rank === 2, function () { return "Error in batchNorm3D: x must be rank 3 but got rank " + + ($x.rank + "."); }); + assert($mean.rank === 2 || $mean.rank === 1, function () { return "Error in batchNorm2D: mean must be rank 2 or rank 1 but " + + ("got rank " + $mean.rank + "."); }); + assert($variance.rank === 2 || $variance.rank === 1, function () { return "Error in batchNorm2D: variance must be rank 2 or rank 1 " + + ("but got rank " + $variance.rank + "."); }); + if ($scale != null) { + assert($scale.rank === 2 || $scale.rank === 1, function () { return "Error in batchNorm2D: scale must be rank 2 or rank 1 " + + ("but got rank " + $scale.rank + "."); }); + } + if ($offset != null) { + assert($offset.rank === 2 || $offset.rank === 1, function () { return "Error in batchNorm2D: offset must be rank 2 or rank 1 " + + ("but got rank " + $offset.rank + "."); }); + } + return batchNorm_($x, $mean, $variance, $offset, $scale, varianceEpsilon); + } + /** + * Batch normalization, strictly for 3D. For the more relaxed version, see + * `tf.batchNorm`. + * + * @param x The input Tensor. + * @param mean A mean Tensor. + * @param variance A variance Tensor. + * @param offset An offset Tensor. + * @param scale A scale Tensor. + * @param varianceEpsilon A small float number to avoid dividing by 0. + */ + function batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon) { + var $x = convertToTensor(x, 'x', 'batchNorm'); + var $mean = convertToTensor(mean, 'mean', 'batchNorm'); + var $variance = convertToTensor(variance, 'variance', 'batchNorm'); + var $scale; + if (scale != null) { + $scale = convertToTensor(scale, 'scale', 'batchNorm'); + } + var $offset; + if (offset != null) { + $offset = convertToTensor(offset, 'offset', 'batchNorm'); + } + assert($x.rank === 3, function () { return "Error in batchNorm3D: x must be rank 3 but got rank " + + ($x.rank + "."); }); + assert($mean.rank === 3 || $mean.rank === 1, function () { return "Error in batchNorm3D: mean must be rank 3 or rank 1 but " + + ("got rank " + $mean.rank + "."); }); + assert($variance.rank === 3 || $variance.rank === 1, function () { return "Error in batchNorm3D: variance must be rank 3 or rank 1 " + + ("but got rank " + $variance.rank + "."); }); + if ($scale != null) { + assert($scale.rank === 3 || $scale.rank === 1, function () { return "Error in batchNorm3D: scale must be rank 3 or rank 1 " + + ("but got rank " + $scale.rank + "."); }); + } + if ($offset != null) { + assert($offset.rank === 3 || $offset.rank === 1, function () { return "Error in batchNorm3D: offset must be rank 3 or rank 1 " + + ("but got rank " + $offset.rank + "."); }); + } + return batchNorm_($x, $mean, $variance, $offset, $scale, varianceEpsilon); + } + /** + * Batch normalization, strictly for 4D. For the more relaxed version, see + * `tf.batchNorm`. + * + * @param x The input Tensor. + * @param mean A mean Tensor. + * @param variance A variance Tensor. + * @param offset An offset Tensor. + * @param scale A scale Tensor. + * @param varianceEpsilon A small float number to avoid dividing by 0. + */ + function batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon) { + var $x = convertToTensor(x, 'x', 'batchNorm'); + var $mean = convertToTensor(mean, 'mean', 'batchNorm'); + var $variance = convertToTensor(variance, 'variance', 'batchNorm'); + var $scale; + if (scale != null) { + $scale = convertToTensor(scale, 'scale', 'batchNorm'); + } + var $offset; + if (offset != null) { + $offset = convertToTensor(offset, 'offset', 'batchNorm'); + } + assert($x.rank === 4, function () { return "Error in batchNorm4D: x must be rank 4 but got rank " + + ($x.rank + "."); }); + assert($mean.rank === 4 || $mean.rank === 1, function () { return "Error in batchNorm4D: mean must be rank 4 or rank 1 but " + + ("got rank " + $mean.rank + "."); }); + assert($variance.rank === 4 || $variance.rank === 1, function () { return "Error in batchNorm4D: variance must be rank 4 or rank 1 " + + ("but got rank " + $variance.rank + "."); }); + if ($scale != null) { + assert($scale.rank === 4 || $scale.rank === 1, function () { return "Error in batchNorm4D: scale must be rank 4 or rank 1 " + + ("but got rank " + $scale.rank + "."); }); + } + if ($offset != null) { + assert($offset.rank === 4 || $offset.rank === 1, function () { return "Error in batchNorm4D: offset must be rank 4 or rank 1 " + + ("but got rank " + $offset.rank + "."); }); + } + return batchNorm_($x, $mean, $variance, $offset, $scale, varianceEpsilon); + } + /** + * @deprecated Please use `tf.batchNorm` instead and note the positional + * argument change of scale, offset, and varianceEpsilon. + */ + function batchNormalization_(x, mean, variance, varianceEpsilon, scale, offset) { + if (varianceEpsilon === void 0) { varianceEpsilon = .001; } + warnDeprecation(); + return batchNorm_(x, mean, variance, offset, scale, varianceEpsilon); + } + /** + * Batch normalization. + * + * As described in + * [http://arxiv.org/abs/1502.03167](http://arxiv.org/abs/1502.03167). + * + * Mean, variance, scale, and offset can be of two shapes: + * - The same shape as the input. + * - In the common case, the depth dimension is the last dimension of x, so + * the values would be an `tf.Tensor1D` of shape [depth]. + * + * Also available are stricter rank-specific methods with the same signature + * as this method that assert that parameters passed are of given rank + * - `tf.batchNorm2d` + * - `tf.batchNorm3d` + * - `tf.batchNorm4d` + * + * @param x The input Tensor. + * @param mean A mean Tensor. + * @param variance A variance Tensor. + * @param offset An offset Tensor. + * @param scale A scale Tensor. + * @param varianceEpsilon A small float number to avoid dividing by 0. + */ + /** @doc {heading: 'Operations', subheading: 'Normalization'} */ + function batchNorm_(x, mean, variance, offset, scale, varianceEpsilon) { + if (varianceEpsilon == null) { + varianceEpsilon = 0.001; + } + var $x = convertToTensor(x, 'x', 'batchNorm'); + var $mean = convertToTensor(mean, 'mean', 'batchNorm'); + var $variance = convertToTensor(variance, 'variance', 'batchNorm'); + var $scale; + if (scale != null) { + $scale = convertToTensor(scale, 'scale', 'batchNorm'); + } + var $offset; + if (offset != null) { + $offset = convertToTensor(offset, 'offset', 'batchNorm'); + } + assert($mean.rank === $variance.rank, function () { return 'Batch normalization gradient requires mean and variance to have ' + + 'equal ranks.'; }); + assert($offset == null || $mean.rank === $offset.rank, function () { return 'Batch normalization gradient requires mean and offset to have ' + + 'equal ranks.'; }); + assert($scale == null || $mean.rank === $scale.rank, function () { return 'Batch normalization gradient requires mean and scale to have ' + + 'equal ranks.'; }); + var x4D; + if ($x.rank === 0 || $x.rank === 1) { + x4D = $x.as4D(1, 1, 1, $x.size); + } + else if ($x.rank === 2) { + x4D = $x.as4D(1, 1, $x.shape[0], $x.shape[1]); + } + else if ($x.rank === 3) { + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + else { + x4D = $x; + } + var der = function (dy, saved) { + var _a = saved, $x = _a[0], $mean = _a[1], $variance = _a[2], $scale = _a[3]; + var scaleValue = $scale == null ? scalar(1) : $scale; + var reductionAxes = getReductionAxes($mean.shape, x4D.shape); + var tileShape = []; + if ($mean.rank === 1) { + for (var i = 0; i < x4D.shape.length - 1; ++i) { + tileShape.push(x4D.shape[i]); + } + tileShape.push(1); + } + var xMinusMean = $x.sub($mean); + var dyTimesScaleValue = dy.mul(scaleValue); + var oneOverSqrtVariance = rsqrt($variance.add(scalar(varianceEpsilon))); + var minusHalfRCube = oneOverSqrtVariance.mul(oneOverSqrtVariance) + .mul(oneOverSqrtVariance) + .mul(scalar(-0.5)); + var derX = function () { + if ($mean.rank === 1) { + return dy + .mul(tile(oneOverSqrtVariance.as4D(1, 1, 1, $mean.shape[0]), tileShape)) + .mul(scaleValue) + .reshape($x.shape); + } + else { + return dy.mul(oneOverSqrtVariance).mul(scaleValue).reshape($x.shape); + } + }; + var derMean = function () { + var meanDer = oneOverSqrtVariance.mul(scalar(-1)).mul(dyTimesScaleValue); + if ($mean.rank === 1) { + meanDer = meanDer.sum(reductionAxes); + } + return meanDer.reshape($mean.shape); + }; + var derVariance = function () { + var varianceDer = minusHalfRCube.mul(xMinusMean).mul(dyTimesScaleValue); + if ($mean.rank === 1) { + varianceDer = varianceDer.sum(reductionAxes); + } + return varianceDer.reshape($mean.shape); + }; + var derScale = function () { + var xMinusMean2TimesRsqrt = xMinusMean.mul(oneOverSqrtVariance); + var scaleDer = dy.mul(xMinusMean2TimesRsqrt); + if ($mean.rank === 1) { + scaleDer = scaleDer.sum(reductionAxes); + } + return scaleDer.reshape($mean.shape); + }; + var derOffset = function () { + var offsetDer = dy; + if ($mean.rank === 1) { + offsetDer = offsetDer.sum(reductionAxes); + } + return offsetDer.reshape($mean.shape); + }; + return { + x: derX, + mean: derMean, + variance: derVariance, + scale: derScale, + offset: derOffset + }; + }; + var inputsToSave = [$x, $mean, $variance, $scale]; + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.batchNormalization(x4D, batchnormReshape4D($mean), batchnormReshape4D($variance), varianceEpsilon, batchnormReshape4D($scale), batchnormReshape4D($offset)); + save([$x, $mean, $variance, $scale]); + return res; + }, { x: $x, mean: $mean, variance: $variance, scale: $scale, offset: $offset }, der, 'BatchNormalization', { varianceEpsilon: varianceEpsilon }, inputsToSave); + return res.reshape($x.shape); + } + function batchnormReshape4D(x) { + if (x == null) { + return null; + } + if (x.rank === 0) { + return x.as1D(); + } + else if (x.rank === 1) { + return x; + } + else if (x.rank === 2) { + return x.as4D(1, 1, x.shape[0], x.shape[1]); + } + else if (x.rank === 3) { + return x.as4D(1, x.shape[0], x.shape[1], x.shape[2]); + } + return x; + } + /** + * @deprecated Please use `tf.batchNorm2d` instead and note the positional + * argument change of scale, offset, and varianceEpsilon. + */ + function batchNormalization2d_(x, mean, variance, varianceEpsilon, scale, offset) { + if (varianceEpsilon === void 0) { varianceEpsilon = .001; } + warnDeprecation(); + return batchNorm2d_(x, mean, variance, offset, scale, varianceEpsilon); + } + /** + * @deprecated Please use `tf.batchNorm3d` instead and note the positional + * argument change of scale, offset, and varianceEpsilon. + */ + function batchNormalization3d_(x, mean, variance, varianceEpsilon, scale, offset) { + if (varianceEpsilon === void 0) { varianceEpsilon = .001; } + warnDeprecation(); + return batchNorm3d_(x, mean, variance, offset, scale, varianceEpsilon); + } + /** + * @deprecated Please use `tf.batchNorm4d` instead and note the positional + * argument change of scale, offset, and varianceEpsilon. + */ + function batchNormalization4d_(x, mean, variance, varianceEpsilon, scale, offset) { + if (varianceEpsilon === void 0) { varianceEpsilon = .001; } + warnDeprecation(); + return batchNorm4d_(x, mean, variance, offset, scale, varianceEpsilon); + } + function warnDeprecation() { + deprecationWarn('tf.batchNormalization() is going away. ' + + 'Use tf.batchNorm() instead, and note the positional argument change ' + + 'of scale, offset, and varianceEpsilon'); + } + var batchNormalization2d = op({ batchNormalization2d_: batchNormalization2d_ }); + var batchNormalization3d = op({ batchNormalization3d_: batchNormalization3d_ }); + var batchNormalization4d = op({ batchNormalization4d_: batchNormalization4d_ }); + var batchNormalization = op({ batchNormalization_: batchNormalization_ }); + var batchNorm = op({ batchNorm_: batchNorm_ }); + var batchNorm2d = op({ batchNorm2d_: batchNorm2d_ }); + var batchNorm3d = op({ batchNorm3d_: batchNorm3d_ }); + var batchNorm4d = op({ batchNorm4d_: batchNorm4d_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Returns the truth value of `NOT x` element-wise. + * + * ```js + * const a = tf.tensor1d([false, true], 'bool'); + * + * a.logicalNot().print(); + * ``` + * + * @param x The input tensor. Must be of dtype 'bool'. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function logicalNot_(x) { + var $x = convertToTensor(x, 'x', 'logicalNot', 'bool'); + return ENGINE.runKernelFunc(function (backend) { return backend.logicalNot($x); }, { $x: $x }); + } + /** + * Returns the truth value of `a AND b` element-wise. Supports broadcasting. + * + * ```js + * const a = tf.tensor1d([false, false, true, true], 'bool'); + * const b = tf.tensor1d([false, true, false, true], 'bool'); + * + * a.logicalAnd(b).print(); + * ``` + * + * @param a The first input tensor. Must be of dtype bool. + * @param b The second input tensor. Must be of dtype bool. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function logicalAnd_(a, b) { + var $a = convertToTensor(a, 'a', 'logicalAnd', 'bool'); + var $b = convertToTensor(b, 'b', 'logicalAnd', 'bool'); + assertAndGetBroadcastShape($a.shape, $b.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.logicalAnd($a, $b); }, { $a: $a, $b: $b }); + } + /** + * Returns the truth value of `a OR b` element-wise. Supports broadcasting. + * + * ```js + * const a = tf.tensor1d([false, false, true, true], 'bool'); + * const b = tf.tensor1d([false, true, false, true], 'bool'); + * + * a.logicalOr(b).print(); + * ``` + * @param a The first input tensor. Must be of dtype bool. + * @param b The second input tensor. Must be of dtype bool. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function logicalOr_(a, b) { + var $a = convertToTensor(a, 'a', 'logicalOr', 'bool'); + var $b = convertToTensor(b, 'b', 'logicalOr', 'bool'); + assertAndGetBroadcastShape($a.shape, $b.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.logicalOr($a, $b); }, { $a: $a, $b: $b }); + } + /** + * Returns the truth value of `a XOR b` element-wise. Supports broadcasting. + * + * ```js + * const a = tf.tensor1d([false, false, true, true], 'bool'); + * const b = tf.tensor1d([false, true, false, true], 'bool'); + * + * a.logicalXor(b).print(); + * ``` + * + * @param a The first input tensor. Must be of dtype bool. + * @param b The second input tensor. Must be of dtype bool. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function logicalXor_(a, b) { + var $a = convertToTensor(a, 'a', 'logicalXor', 'bool'); + var $b = convertToTensor(b, 'b', 'logicalXor', 'bool'); + assertAndGetBroadcastShape($a.shape, $b.shape); + // x ^ y = (x | y) & ~(x & y) + return logicalOr(a, b).logicalAnd(logicalAnd(a, b).logicalNot()); + } + /** + * Returns the elements, either `a` or `b` depending on the `condition`. + * + * If the condition is true, select from `a`, otherwise select from `b`. + * + * ```js + * const cond = tf.tensor1d([false, false, true], 'bool'); + * const a = tf.tensor1d([1 , 2, 3]); + * const b = tf.tensor1d([-1, -2, -3]); + * + * a.where(cond, b).print(); + * ``` + * + * @param condition The input condition. Must be of dtype bool. + * @param a If `condition` is rank 1, `a` may have a higher rank but + * its first dimension must match the size of `condition`. + * @param b A tensor with the same shape and type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function where_(condition, a, b) { + var $a = convertToTensor(a, 'a', 'where'); + var $b = convertToTensor(b, 'b', 'where'); + var $condition = convertToTensor(condition, 'condition', 'where', 'bool'); + assertShapesMatch($a.shape, $b.shape, 'Error in where: '); + if ($condition.rank === 1) { + // If condition rank is 1, then the first dimension must match the size of + // condition. + assert($condition.shape[0] === $a.shape[0], function () { return 'The first dimension of `a` must match the size of `condition`.'; }); + } + else { + // A must have the same shape as condition. + assertShapesMatch($condition.shape, $b.shape, 'Error in where: '); + } + // TODO(julianoks): Return null for condition gradient + // when backprop supports it. + var grad = function (dy, saved) { + var $condition = saved[0]; + return { + $condition: function () { return zerosLike($condition).toFloat(); }, + $a: function () { return dy.mul($condition.cast(dy.dtype)); }, + $b: function () { return dy.mul($condition.logicalNot().cast(dy.dtype)); } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.select($condition, $a, $b); + save([$condition]); + return res; + }, { $condition: $condition, $a: $a, $b: $b }, grad); + } + /** + * Returns the coordinates of true elements of condition. + * + * The coordinates are returned in a 2-D tensor where the first dimension (rows) + * represents the number of true elements, and the second dimension (columns) + * represents the coordinates of the true elements. Keep in mind, the shape of + * the output tensor can vary depending on how many true values there are in + * input. Indices are output in row-major order. The resulting tensor has the + * shape `[numTrueElems, condition.rank]`. + * + * This is analogous to calling the python `tf.where(cond)` without an x or y. + * + * ```js + * const cond = tf.tensor1d([false, false, true], 'bool'); + * const result = await tf.whereAsync(cond); + * result.print(); + * ``` + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function whereAsync_(condition) { + return __awaiter(this, void 0, void 0, function () { + var $condition, vals, res; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + $condition = convertToTensor(condition, 'condition', 'whereAsync', 'bool'); + return [4 /*yield*/, $condition.data()]; + case 1: + vals = _a.sent(); + res = whereImpl($condition.shape, vals); + if (condition !== $condition) { + $condition.dispose(); + } + return [2 /*return*/, res]; + } + }); + }); + } + var logicalAnd = op({ logicalAnd_: logicalAnd_ }); + var logicalNot = op({ logicalNot_: logicalNot_ }); + var logicalOr = op({ logicalOr_: logicalOr_ }); + var logicalXor = op({ logicalXor_: logicalXor_ }); + var where = op({ where_: where_ }); + var whereAsync = whereAsync_; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Adds two `tf.Tensor`s element-wise, A + B. Supports broadcasting. + * + * We also expose `tf.addStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3, 4]); + * const b = tf.tensor1d([10, 20, 30, 40]); + * + * a.add(b).print(); // or tf.add(a, b) + * ``` + * + * ```js + * // Broadcast add a with b. + * const a = tf.scalar(5); + * const b = tf.tensor1d([10, 20, 30, 40]); + * + * a.add(b).print(); // or tf.add(a, b) + * ``` + * @param a The first `tf.Tensor` to add. + * @param b The second `tf.Tensor` to add. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function add_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'add'); + var $b = convertToTensor(b, 'b', 'add'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy) { + var derA = function () { + var res = dy; + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($a.shape); + }; + var derB = function () { + var res = dy; + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($b.shape); + }; + return { a: derA, b: derB }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.add($a, $b); }, { a: $a, b: $b }, der, 'Add'); + } + /** + * Adds a list of `tf.Tensor`s element-wise, each with the same shape and dtype. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor1d([3, 4]); + * const c = tf.tensor1d([5, 6]); + * + * tf.addN([a, b, c]).print(); + * ``` + * @param tensors A list of tensors with the same shape and dtype. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function addN_(tensors) { + assert(Array.isArray(tensors), function () { return 'The argument passed to tf.addN() must be a list of tensors'; }); + assert(tensors.length >= 1, function () { return "Must pass at least one tensor to tf.addN(), but got " + + ("" + tensors.length); }); + var $tensors = tensors.map(function (t, i) { return convertToTensor(t, "tensors" + i, 'addN'); }); + var firstTensor = $tensors[0]; + $tensors.forEach(function (t) { + if (t.dtype !== firstTensor.dtype) { + throw new Error('All tensors passed to tf.addN() must have the same dtype'); + } + }); + $tensors.forEach(function (t) { + if (!arraysEqual(t.shape, firstTensor.shape)) { + throw new Error('All tensors passed to tf.addN() must have the same shape'); + } + }); + var der = function (dy) { + var ders = {}; + $tensors.forEach(function (t, i) { + ders[i] = function () { return dy.clone(); }; + }); + return ders; + }; + var inputs = $tensors; + return ENGINE.runKernelFunc(function (backend) { return backend.addN($tensors); }, inputs, der, 'AddN'); + } + /** + * Adds two `tf.Tensor`s element-wise, A + B. + * + * Inputs must be the same shape. For broadcasting support, use add() instead. + * + * @param a The first Tensor to add element-wise. + * @param b The second Tensor to add element-wise. + */ + function addStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'addStrict'); + var $b = convertToTensor(b, 'b', 'addStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in addStrict: '); + return $a.add($b); + } + /** + * Subtracts two `tf.Tensor`s element-wise, A - B. Supports broadcasting. + * + * We also expose `tf.subStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([10, 20, 30, 40]); + * const b = tf.tensor1d([1, 2, 3, 4]); + * + * a.sub(b).print(); // or tf.sub(a, b) + * ``` + * + * ```js + * // Broadcast subtract a with b. + * const a = tf.tensor1d([10, 20, 30, 40]); + * const b = tf.scalar(5); + * + * a.sub(b).print(); // or tf.sub(a, b) + * ``` + * @param a The first `tf.Tensor` to subtract from. + * @param b The second `tf.Tensor` to be subtracted. Must have the same dtype as + * `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function sub_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'sub'); + var $b = convertToTensor(b, 'b', 'sub'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy) { + var derA = function () { + var res = dy; + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($a.shape); + }; + var derB = function () { + var res = dy; + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.neg().reshape($b.shape); + }; + return { a: derA, b: derB }; + }; + return ENGINE.runKernelFunc(function (backend) { return backend.subtract($a, $b); }, { a: $a, b: $b }, der, 'Sub'); + } + /** + * Subtracts two `tf.Tensor`s element-wise, A - B. Inputs must + * be the same shape. + * + * For broadcasting support, use `tf.sub` instead. + * + * @param a The first Tensor to subtract element-wise. + * @param b The second Tensor to subtract element-wise. + */ + function subStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'subStrict'); + var $b = convertToTensor(b, 'b', 'subStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in subStrict: '); + return $a.sub($b); + } + /** + * Computes the power of one `tf.Tensor` to another. Supports broadcasting. + * + * Given a `tf.Tensor` x and a `tf.Tensor` y, this operation computes x^y for + * corresponding elements in x and y. The result's dtype will be the upcasted + * type of the `base` and `exp` dtypes. + * + * ```js + * const a = tf.tensor([[2, 3], [4, 5]]) + * const b = tf.tensor([[1, 2], [3, 0]]).toInt(); + * + * a.pow(b).print(); // or tf.pow(a, b) + * ``` + * + * ```js + * const a = tf.tensor([[1, 2], [3, 4]]) + * const b = tf.tensor(2).toInt(); + * + * a.pow(b).print(); // or tf.pow(a, b) + * ``` + * We also expose `powStrict` which has the same signature as this op and + * asserts that `base` and `exp` are the same shape (does not broadcast). + * + * @param base The base `tf.Tensor` to pow element-wise. + * @param exp The exponent `tf.Tensor` to pow element-wise. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function pow_(base, exp) { + var $base = convertToTensor(base, 'base', 'pow'); + var $exp = convertToTensor(exp, 'exp', 'pow'); + var outShape = assertAndGetBroadcastShape($base.shape, $exp.shape); + base = $base.cast(upcastType($base.dtype, $exp.dtype)); + exp = $exp.cast(upcastType($base.dtype, $exp.dtype)); + var grad = function (dy, saved) { + var $base = saved[0], $exp = saved[1], y = saved[2]; + var derBase = function () { + var expFloat = $exp.toFloat(); + var res = dy.mul(expFloat.mul($base.pow(expFloat.sub(scalar(1))))); + var reduceAxes = getReductionAxes($base.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($base.shape); + }; + var derExp = function () { + var condition = $base.greater(0); + var logBase = $base.log().where(condition, zerosLike($base)); + var res = dy.mul(y.mul(logBase)); + var reduceAxes = getReductionAxes($exp.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($exp.shape); + }; + return { $base: derBase, $exp: derExp }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.pow($base, $exp); + save([$base, $exp, y]); + return y; + }, { $base: $base, $exp: $exp }, grad); + } + /** + * Computes the power of one `tf.Tensor` to another. Inputs must + * be the same shape. + * + * For broadcasting support, use `tf.pow` instead. + * + * @param base The base tensor to pow element-wise. + * @param exp The exponent tensor to pow element-wise. + */ + function powStrict_(base, exp) { + assertShapesMatch(base.shape, exp.shape, 'Error in powStrict: '); + return base.pow(exp); + } + /** + * Multiplies two `tf.Tensor`s element-wise, A * B. Supports broadcasting. + * + * We also expose `tf.mulStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3, 4]); + * const b = tf.tensor1d([2, 3, 4, 5]); + * + * a.mul(b).print(); // or tf.mul(a, b) + * ``` + * + * ```js + * // Broadcast mul a with b. + * const a = tf.tensor1d([1, 2, 3, 4]); + * const b = tf.scalar(5); + * + * a.mul(b).print(); // or tf.mul(a, b) + * ``` + * @param a The first tensor to multiply. + * @param b The second tensor to multiply. Must have the same dtype as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function mul_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'mul'); + var $b = convertToTensor(b, 'b', 'mul'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var res = dy.mul($b.toFloat()); + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + return res.sum(reduceAxes).reshape($a.shape); + } + return res; + }; + var derB = function () { + var res = dy.mul($a.toFloat()); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + return res.sum(reduceAxes).reshape($b.shape); + } + return res; + }; + return { a: derA, b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.multiply($a, $b); + save([$a, $b]); + return res; + }, { a: $a, b: $b }, der, 'Mul'); + } + /** + * Multiplies two `tf.Tensor`s element-wise, A * B. + * + * Inputs must be the same shape. For broadcasting support, use `tf.mul`. + * + * @param a The first tensor to multiply. + * @param b The first tensor to multiply. Must have the same + * dtype as `a`. + */ + function mulStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'mul'); + var $b = convertToTensor(b, 'b', 'mul'); + assertShapesMatch($a.shape, $b.shape, 'Error in multiplyStrict: '); + return $a.mul($b); + } + /** + * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. + * + * We also expose `tf.divStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 9, 16]); + * const b = tf.tensor1d([1, 2, 3, 4]); + * + * a.div(b).print(); // or tf.div(a, b) + * ``` + * + * ```js + * // Broadcast div a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(2); + * + * a.div(b).print(); // or tf.div(a, b) + * ``` + * + * @param a The first tensor as the numerator. + * @param b The second tensor as the denominator. Must have the same dtype as + * `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function div_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'div'); + var $b = convertToTensor(b, 'b', 'div'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + if ($a.dtype === 'int32' && $b.dtype === 'int32') { + return floorDiv($a, $b); + } + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var res = dy.div($b.toFloat()); + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + return res.sum(reduceAxes).reshape($a.shape); + } + return res; + }; + var derB = function () { + var res = dy.mul($a.toFloat()); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes).reshape($b.shape); + } + var tmp = $b.square(); + return res.div(tmp.toFloat()).neg(); + }; + return { a: derA, b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.realDivide($a, $b); + save([$a, $b]); + return res; + }, { a: $a, b: $b }, der, 'Div'); + } + /** + * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. Return 0 + * if denominator is 0. + * + * We also expose `tf.divStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 9, 16]); + * const b = tf.tensor1d([1, 2, 3, 4]); + * const c = tf.tensor1d([0, 0, 0, 0]); + * + * a.divNoNan(b).print(); // or tf.divNoNan(a, b) + * a.divNoNan(c).print(); // or tf.divNoNan(a, c) + * ``` + * + * ```js + * // Broadcast div a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(2); + * const c = tf.scalar(0); + * + * a.divNoNan(b).print(); // or tf.divNoNan(a, b) + * a.divNoNan(c).print(); // or tf.divNoNan(a, c) + * ``` + * + * @param a The first tensor as the numerator. + * @param b The second tensor as the denominator. Must have the same dtype as + * `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function divNoNan_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'div'); + var $b = convertToTensor(b, 'b', 'div'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var divResult = div($a, $b); + var zeros = zerosLike(divResult); + var bEqualsZero = $b.equal(zeros); + return where(bEqualsZero, zeros, divResult); + } + /** + * Divides two `tf.Tensor`s element-wise, A / B. Supports broadcasting. + * The result is rounded with floor function. + * + * + * ```js + * const a = tf.tensor1d([1, 4, 9, 16]); + * const b = tf.tensor1d([1, 2, 3, 4]); + * + * a.floorDiv(b).print(); // or tf.div(a, b) + * ``` + * + * ```js + * // Broadcast div a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(2); + * + * a.floorDiv(b).print(); // or tf.floorDiv(a, b) + * ``` + * + * @param a The first tensor as the numerator. + * @param b The second tensor as the denominator. Must have the same dtype as + * `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function floorDiv_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'floorDiv'); + var $b = convertToTensor(b, 'b', 'floorDiv'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var res = dy.div($b.toFloat()); + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + return res.sum(reduceAxes).reshape($a.shape); + } + return res; + }; + var derB = function () { + var res = dy.mul($a.toFloat()); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes).reshape($b.shape); + } + var tmp = $b.square(); + return res.div(tmp.toFloat()).neg(); + }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.floorDiv($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + /** + * Divides two `tf.Tensor`s element-wise, A / B. Inputs must + * be the same shape. + * + * @param a The first tensor as the numerator for element-wise division. + * @param b The second tensor as the denominator for element-wise division. + */ + function divStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'div'); + var $b = convertToTensor(b, 'b', 'div'); + assertShapesMatch($a.shape, $b.shape, 'Error in divideStrict: '); + return $a.div($b); + } + /** + * Returns the mod of a and b element-wise. + * `floor(x / y) * y + mod(x, y) = x` + * Supports broadcasting. + * + * We also expose `tf.modStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 3, 16]); + * const b = tf.tensor1d([1, 2, 9, 4]); + * + * a.mod(b).print(); // or tf.mod(a, b) + * ``` + * + * ```js + * // Broadcast a mod b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(5); + * + * a.mod(b).print(); // or tf.mod(a, b) + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function mod_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'mod'); + var $b = convertToTensor(b, 'b', 'mod'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + return dy.sum(reduceAxes).reshape($a.shape); + } + return dy; + }; + var derB = function () { + var res = dy.mul($a.div($b).floor().neg()); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + return res.sum(reduceAxes).reshape($b.shape); + } + return res; + }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.mod($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + /** + * Returns the mod of a and b (`a < b ? a : b`) element-wise. Inputs must + * be the same shape. For broadcasting support, use mod(). + * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. + */ + function modStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'modStrict'); + var $b = convertToTensor(b, 'b', 'modStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in modStrict: '); + return $a.mod($b); + } + /** + * Returns the min of a and b (`a < b ? a : b`) element-wise. + * Supports broadcasting. + * + * We also expose `minimumStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 3, 16]); + * const b = tf.tensor1d([1, 2, 9, 4]); + * + * a.minimum(b).print(); // or tf.minimum(a, b) + * ``` + * + * ```js + * // Broadcast minimum a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(5); + * + * a.minimum(b).print(); // or tf.minimum(a, b) + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function minimum_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'minimum'); + var $b = convertToTensor(b, 'b', 'minimum'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + if ($a.dtype === 'bool') { + $a = $a.toInt(); + $b = $b.toInt(); + } + assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { return dy.mul($a.lessEqual($b).toFloat()); }; + var derB = function () { return dy.mul($a.greater($b).toFloat()); }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.minimum($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + /** + * Returns the min of a and b (`a < b ? a : b`) element-wise. Inputs must + * be the same shape. For broadcasting support, use minimum(). + * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. + */ + function minimumStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'minimumStrict'); + var $b = convertToTensor(b, 'b', 'minimumStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in minimumStrict: '); + return $a.minimum($b); + } + /** + * Returns the max of a and b (`a > b ? a : b`) element-wise. + * Supports broadcasting. + * + * We also expose `tf.maximumStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 3, 16]); + * const b = tf.tensor1d([1, 2, 9, 4]); + * + * a.maximum(b).print(); // or tf.maximum(a, b) + * ``` + * + * ```js + * // Broadcast maximum a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(5); + * + * a.maximum(b).print(); // or tf.maximum(a, b) + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function maximum_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'maximum'); + var $b = convertToTensor(b, 'b', 'maximum'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + if ($a.dtype === 'bool') { + $a = $a.toInt(); + $b = $b.toInt(); + } + assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { return dy.mul($a.greaterEqual($b).toFloat()); }; + var derB = function () { return dy.mul($a.less($b).toFloat()); }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.maximum($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + /** + * Returns the max of a and b (`a > b ? a : b`) element-wise. Inputs must + * be the same shape. For broadcasting support, use maximum(). + * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. + */ + function maximumStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'maximumStrict'); + var $b = convertToTensor(b, 'b', 'maximumStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in maximumStrict: '); + return $a.maximum($b); + } + /** + * Returns (a - b) * (a - b) element-wise. + * Supports broadcasting. + * + * We also expose `tf.squaredDifferenceStrict` which has the same signature as + * this op and asserts that `a` and `b` are the same shape (does not + * broadcast). + * + * ```js + * const a = tf.tensor1d([1, 4, 3, 16]); + * const b = tf.tensor1d([1, 2, 9, 4]); + * + * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b) + * ``` + * + * ```js + * // Broadcast squared difference a with b. + * const a = tf.tensor1d([2, 4, 6, 8]); + * const b = tf.scalar(5); + * + * a.squaredDifference(b).print(); // or tf.squaredDifference(a, b) + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Arithmetic'} */ + function squaredDifference_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'squaredDifference'); + var $b = convertToTensor(b, 'b', 'squaredDifference'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var two = scalar(2); + var derA = function () { return dy.mul($a.sub($b).mul(two)); }; + var derB = function () { return dy.mul($b.sub($a).mul(two)); }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.squaredDifference($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + /** + * Returns (a - b) * (a - b) element-wise. + * + * Inputs must be the same shape. For broadcasting support, use + * `tf.squaredDifference` instead. + * + * @param a The first tensor. + * @param b The second tensor. Must have the same type as `a`. + */ + function squaredDifferenceStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'squaredDifferenceStrict'); + var $b = convertToTensor(b, 'b', 'squaredDifferenceStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in squaredDifferenceStrict: '); + return $a.squaredDifference($b); + } + /** + * Computes arctangent of `tf.Tensor`s a / b element-wise: `atan2(a, b)`. + * Supports broadcasting. + * + * ```js + * const a = tf.tensor1d([1.0, 1.0, -1.0, .7]); + * const b = tf.tensor1d([2.0, 13.0, 3.5, .21]); + * + * tf.atan2(a, b).print() + * ``` + * + * @param a The first tensor. + * @param b The second tensor. Must have the same dtype as `a`. + * + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function atan2_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'atan2'); + var $b = convertToTensor(b, 'b', 'atan2'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var outShape = assertAndGetBroadcastShape($a.shape, $b.shape); + var der = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + var derA = function () { + var d = add($a.square(), $b.square()); + var res = dy.mul($b.div(d)); + var reduceAxes = getReductionAxes($a.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($a.shape); + }; + var derB = function () { + var d = add($a.square(), $b.square()); + var res = neg(dy.mul($a.div(d))); + var reduceAxes = getReductionAxes($b.shape, outShape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($b.shape); + }; + return { $a: derA, $b: derB }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.atan2($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, der); + } + var add = op({ add_: add_ }); + var addN = op({ addN_: addN_ }); + var addStrict = op({ addStrict_: addStrict_ }); + var atan2 = op({ atan2_: atan2_ }); + var div = op({ div_: div_ }); + var divNoNan = op({ divNoNan_: divNoNan_ }); + var divStrict = op({ divStrict_: divStrict_ }); + var floorDiv = op({ floorDiv_: floorDiv_ }); + var maximum = op({ maximum_: maximum_ }); + var maximumStrict = op({ maximumStrict_: maximumStrict_ }); + var minimum = op({ minimum_: minimum_ }); + var minimumStrict = op({ minimumStrict_: minimumStrict_ }); + var mod = op({ mod_: mod_ }); + var modStrict = op({ modStrict_: modStrict_ }); + var mul = op({ mul_: mul_ }); + var mulStrict = op({ mulStrict_: mulStrict_ }); + var pow = op({ pow_: pow_ }); + var powStrict = op({ powStrict_: powStrict_ }); + var squaredDifference = op({ squaredDifference_: squaredDifference_ }); + var squaredDifferenceStrict = op({ squaredDifferenceStrict_: squaredDifferenceStrict_ }); + var sub = op({ sub_: sub_ }); + var subStrict = op({ subStrict_: subStrict_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Returns the truth value of (a != b) element-wise. Supports broadcasting. + * + * We also expose `tf.notEqualStrict` which has the same signature as this op + * and asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([0, 2, 3]); + * + * a.notEqual(b).print(); + * ``` + * @param a The first input tensor. + * @param b The second input tensor. Must have the same dtype as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function notEqual_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'notEqual'); + var $b = convertToTensor(b, 'b', 'notEqual'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + assertAndGetBroadcastShape($a.shape, $b.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.notEqual($a, $b); }, { $a: $a, $b: $b }); + } + /** + * Strict version of `tf.notEqual` that forces `a` and `b` to be of the same + * shape. + * + * @param a The first input tensor. + * @param b The second input tensor. Must have the same shape and dtype as + * `a`. + */ + function notEqualStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'notEqualStrict'); + var $b = convertToTensor(b, 'b', 'notEqualStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in notEqualStrict: '); + return $a.notEqual($b); + } + /** + * Returns the truth value of (a < b) element-wise. Supports broadcasting. + * + * We also expose `tf.lessStrict` which has the same signature as this op and + * asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([2, 2, 2]); + * + * a.less(b).print(); + * ``` + * @param a The first input tensor. + * @param b The second input tensor. Must have the same dtype as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function less_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'less'); + var $b = convertToTensor(b, 'b', 'less'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + assertAndGetBroadcastShape($a.shape, $b.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.less($a, $b); }, { $a: $a, $b: $b }); + } + /** + * Strict version of `tf.less` that forces `a` and `b` to be of the same + * shape. + * + * @param a The first input tensor. + * @param b The second input tensor. Must have the same shape and dtype as + * `a`. + */ + function lessStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'lessStrict'); + var $b = convertToTensor(b, 'b', 'lessStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in lessStrict: '); + return $a.less($b); + } + /** + * Returns the truth value of (a == b) element-wise. Supports broadcasting. + * + * We also expose `tf.equalStrict` which has the same signature as this op + * and asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([2, 2, 2]); + * + * a.equal(b).print(); + * ``` + * + * @param a The first input tensor. + * @param b The second input tensor. Must have the same dtype as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function equal_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'equal'); + var $b = convertToTensor(b, 'b', 'equal'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + assertAndGetBroadcastShape($a.shape, $b.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.equal($a, $b); }, { $a: $a, $b: $b }); + } + function equalStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'equalStrict'); + var $b = convertToTensor(b, 'b', 'equalStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in equalStrict: '); + return $a.equal($b); + } + /** + * Returns the truth value of (a <= b) element-wise. Supports broadcasting. + * + * We also expose `tf.lessEqualStrict` which has the same signature as this op + * and asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([2, 2, 2]); + * + * a.lessEqual(b).print(); + * ``` + * + * @param a The first input tensor. + * @param b The second input tensor. Must have the same dtype as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function lessEqual_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'lessEqual'); + var $b = convertToTensor(b, 'b', 'lessEqual'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + assertAndGetBroadcastShape($a.shape, $b.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.lessEqual($a, $b); }, { $a: $a, $b: $b }); + } + function lessEqualStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'lessEqualStrict'); + var $b = convertToTensor(b, 'b', 'lessEqualStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in lessEqualStrict: '); + return $a.lessEqual($b); + } + /** + * Returns the truth value of (a > b) element-wise. Supports broadcasting. + * + * We also expose `tf.greaterStrict` which has the same signature as this + * op and asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([2, 2, 2]); + * + * a.greater(b).print(); + * ``` + * + * @param a The first input tensor. + * @param b The second input tensor. Must have the same dtype as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function greater_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'greater'); + var $b = convertToTensor(b, 'b', 'greater'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + assertAndGetBroadcastShape($a.shape, $b.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.greater($a, $b); }, { $a: $a, $b: $b }); + } + function greaterStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'greaterStrict'); + var $b = convertToTensor(b, 'b', 'greaterStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in greaterStrict: '); + return $a.greater($b); + } + /** + * Returns the truth value of (a >= b) element-wise. Supports broadcasting. + * + * We also expose `tf.greaterEqualStrict` which has the same signature as this + * op and asserts that `a` and `b` are the same shape (does not broadcast). + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([2, 2, 2]); + * + * a.greaterEqual(b).print(); + * ``` + * + * @param a The first input tensor. + * @param b The second input tensor. Must have the same dtype as `a`. + */ + /** @doc {heading: 'Operations', subheading: 'Logical'} */ + function greaterEqual_(a, b) { + var _a; + var $a = convertToTensor(a, 'a', 'greaterEqual'); + var $b = convertToTensor(b, 'b', 'greaterEqual'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + assertAndGetBroadcastShape($a.shape, $b.shape); + var grad = function (dy, saved) { + var $a = saved[0], $b = saved[1]; + return { $a: function () { return zerosLike($a); }, $b: function () { return zerosLike($b); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.greaterEqual($a, $b); + save([$a, $b]); + return res; + }, { $a: $a, $b: $b }, grad); + } + function greaterEqualStrict_(a, b) { + var $a = convertToTensor(a, 'a', 'greaterEqualStrict'); + var $b = convertToTensor(b, 'b', 'greaterEqualStrict'); + assertShapesMatch($a.shape, $b.shape, 'Error in greaterEqualStrict: '); + return $a.greaterEqual($b); + } + var equal = op({ equal_: equal_ }); + var equalStrict = op({ equalStrict_: equalStrict_ }); + var greater = op({ greater_: greater_ }); + var greaterEqual = op({ greaterEqual_: greaterEqual_ }); + var greaterEqualStrict = op({ greaterEqualStrict_: greaterEqualStrict_ }); + var greaterStrict = op({ greaterStrict_: greaterStrict_ }); + var less = op({ less_: less_ }); + var lessEqual = op({ lessEqual_: lessEqual_ }); + var lessEqualStrict = op({ lessEqualStrict_: lessEqualStrict_ }); + var lessStrict = op({ lessStrict_: lessStrict_ }); + var notEqual = op({ notEqual_: notEqual_ }); + var notEqualStrict = op({ notEqualStrict_: notEqualStrict_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the sum along segments of a `tf.Tensor`. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * const segmentIds = tf.tensor1d([1, 2, 0, 1], 'int32'); + * const numSegments = 3; + * + * x.unsortedSegmentSum(segmentIds, numSegments).print() + * //or tf.unsortedSegmentSum(x, segmentIds, numSegments) + * ``` + * @param x The `tf.Tensor` that will be summed along its segments. + * @param segmentIds A `tf.Tensor1D` whose rank is equal to the rank of `x`'s + * dimension along the `axis`. Maps each element of `x` to a segment. + * @param numSegments The number of distinct `segmentIds`. + */ + /** @doc {heading: 'Operations', subheading: 'Segment'} */ + function unsortedSegmentSum_(x, segmentIds, numSegments) { + var $x = convertToTensor(x, 'x', 'unsortedSegmentSum'); + var $segmentIds = convertToTensor(segmentIds, 'segmentIds', 'unsortedSegmentSum', 'int32'); + assert(isInt(numSegments), function () { return 'numSegments must be of dtype int'; }); + var gradFunc = function (dy, saved) { + var $segmentIds = saved[0]; + var derX = function () { + return gatherDropNegatives(dy, $segmentIds); + }; + return { $x: derX }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.unsortedSegmentSum($x, $segmentIds, numSegments); + save([$segmentIds]); + return res; + }, { $x: $x }, gradFunc); + } + /** + * Gather slices from tensor `x`'s axis `axis` according to `indices`. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * const indices = tf.tensor1d([1, 3, 3], 'int32'); + * + * x.gather(indices).print(); + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * const indices = tf.tensor1d([1, 1, 0], 'int32'); + * + * x.gather(indices).print(); + * ``` + * @param x The input tensor whose slices to be gathered. + * @param indices The indices of the values to extract. + * @param axis The axis over which to select values. Defaults to 0. + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function gather_(x, indices, axis) { + if (axis === void 0) { axis = 0; } + var $x = convertToTensor(x, 'x', 'gather'); + var $indices = convertToTensor(indices, 'indices', 'gather', 'int32'); + axis = parseAxisParam(axis, $x.shape)[0]; + var shapeInfo = collectGatherOpShapeInfo($x, $indices, axis); + var grad = function (dy, saved) { + var $indices = saved[0]; + var derX = function () { + var paramsShape = $x.shape; + var indicesSize = $indices.size; + var outerShape = paramsShape.slice(0, axis); + var outerDims = outerShape.length; + var innerShape = paramsShape.slice(axis, paramsShape.length).slice(1); + var innerDims = innerShape.length; + var outerAxesIndices = arrayRange(0, outerDims); + var innerAxesIndices = arrayRange(outerDims + 1, outerDims + 1 + innerDims); + var valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]); + var values = dy.reshape(valuesShape); + var reshapedIndices = $indices.reshape([indicesSize]); + var transposeDims = arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]); + var valuesTranspose = values.transpose(transposeDims); + var paramsGrad = unsortedSegmentSum(valuesTranspose, reshapedIndices, $x.shape[axis]); + var invertTransposeDims = getUndoAxesPermutation(transposeDims); + paramsGrad = paramsGrad.transpose(invertTransposeDims); + return paramsGrad; + }; + return { $x: derX }; + }; + return (ENGINE.runKernelFunc(function (backend, save) { + var res = backend.gather($x, $indices.flatten(), axis); + save([$indices]); + return res; + }, { $x: $x }, grad)).reshape(shapeInfo.outputShape); + } + function arrayRange(start, stop) { + var result = []; + for (var i = start; i < stop; ++i) { + result.push(i); + } + return result; + } + function arrayConcat(arrays) { + var result = []; + for (var i = 0; i < arrays.length; ++i) { + for (var j = 0; j < arrays[i].length; ++j) { + result.push(arrays[i][j]); + } + } + return result; + } + function gatherDropNegatives(x, indices) { + // Helper function for unsorted segment ops. Gathers params for + // positive segment ids and gathers 0 for inputs with negative segment id. + // Mirrors _GatherDropNegatives from tensorflow/python/ops/math_grad.py + var zeroClippedIndices = maximum(indices, zerosLike(indices)); + var gathered = gather(x, zeroClippedIndices); + var isPositive = greaterEqual(indices, scalar(0, 'int32')); + var numIters = gathered.rank - isPositive.rank; + for (var i = 0; i < numIters; ++i) { + isPositive = expandDims(isPositive, i + 1); + } + isPositive = logicalAnd(isPositive, ones$1(gathered.shape, 'bool')); + var zeroSlice = zerosLike(gathered); + return where(isPositive, gathered, zeroSlice); + } + var gather = op({ gather_: gather_ }); + var unsortedSegmentSum = op({ unsortedSegmentSum_: unsortedSegmentSum_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Apply boolean mask to tensor. + * + * ```js + * const tensor = tf.tensor2d([1, 2, 3, 4, 5, 6], [3, 2]); + * const mask = tf.tensor1d([1, 0, 1], 'bool'); + * const result = await tf.booleanMaskAsync(tensor, mask); + * result.print(); + * ``` + * + * @param tensor N-D tensor. + * @param mask K-D boolean tensor, K <= N and K must be known statically. + * @param axis A 0-D int Tensor representing the axis in tensor to mask from. + * By default, axis is 0 which will mask from the first dimension. + * Otherwise K + axis <= N. + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function booleanMaskAsync_(tensor, mask, axis) { + return __awaiter(this, void 0, void 0, function () { + var $tensor, $mask, axisFrom, maskDim, tensorShape, leadingSize, i, targetTensorShape, reshapedTensor, reshapedMask, positivePositions, indices, res; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + $tensor = convertToTensor(tensor, 'tensor', 'boolMask'); + $mask = convertToTensor(mask, 'mask', 'boolMask', 'bool'); + axisFrom = axis == null ? 0 : axis; + maskDim = $mask.rank; + tensorShape = $tensor.shape; + assert(maskDim > 0, function () { return 'mask cannot be scalar'; }); + assertShapesMatch(tensorShape.slice(axisFrom, axisFrom + maskDim), $mask.shape, "mask's shape must match the first K dimensions of tensor's shape,"); + leadingSize = 1; + for (i = axisFrom; i < axisFrom + maskDim; i++) { + leadingSize *= tensorShape[i]; + } + targetTensorShape = tensorShape.slice(0, axisFrom) + .concat([leadingSize], tensorShape.slice(axisFrom + maskDim)); + reshapedTensor = $tensor.reshape(targetTensorShape); + reshapedMask = $mask.reshape([-1]); + return [4 /*yield*/, whereAsync(reshapedMask)]; + case 1: + positivePositions = _a.sent(); + indices = positivePositions.squeeze([1]); + res = gather(reshapedTensor, indices, axisFrom); + // Ensure no memory leak. + if (tensor !== $tensor) { + $tensor.dispose(); + } + if (mask !== $mask) { + $mask.dispose(); + } + indices.dispose(); + reshapedTensor.dispose(); + reshapedMask.dispose(); + positivePositions.dispose(); + return [2 /*return*/, res]; + } + }); + }); + } + var booleanMaskAsync = booleanMaskAsync_; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes a 1D convolution over the input x. + * + * @param x The input tensor, of rank 3 or rank 2, of shape + * `[batch, width, inChannels]`. If rank 2, batch of 1 is assumed. + * @param filter The filter, rank 3, of shape + * `[filterWidth, inDepth, outDepth]`. + * @param stride The number of entries by which the filter is moved right at + * each step. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC", + * the data is stored in the order of [batch, in_width, in_channels]. Only + * "NWC" is currently supported. + * @param dilation The dilation rate in which we sample input values in + * atrous convolution. Defaults to `1`. If it is greater than 1, then + * stride must be `1`. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function conv1d_(x, filter, stride, pad, dataFormat, dilation, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NWC'; } + if (dilation === void 0) { dilation = 1; } + var $x = convertToTensor(x, 'x', 'conv1d'); + var $filter = convertToTensor(filter, 'filter', 'conv1d'); + var x3D = $x; + var reshapedTo3D = false; + if ($x.rank === 2) { + reshapedTo3D = true; + x3D = $x.as3D(1, $x.shape[0], $x.shape[1]); + } + assert(x3D.rank === 3, function () { return "Error in conv1d: input must be rank 3, but got rank " + x3D.rank + "."; }); + assert($filter.rank === 3, function () { return "Error in conv1d: filter must be rank 3, but got rank " + + ($filter.rank + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in conv1d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + assert(x3D.shape[2] === $filter.shape[1], function () { return "Error in conv1d: depth of input (" + x3D.shape[2] + ") must match " + + ("input depth for filter " + $filter.shape[1] + "."); }); + assert(eitherStridesOrDilationsAreOne(stride, dilation), function () { return 'Error in conv1D: Either stride or dilation must be 1. ' + + ("Got stride " + stride + " and dilation '" + dilation + "'"); }); + assert(dataFormat === 'NWC', function () { return "Error in conv1d: got dataFormat of " + dataFormat + " but only NWC is currently supported."; }); + var filter4D = $filter.as4D(1, $filter.shape[0], $filter.shape[1], $filter.shape[2]); + var input4D = x3D.as4D(x3D.shape[0], 1, x3D.shape[1], x3D.shape[2]); + var strides = [1, stride]; + var dilations = [1, dilation]; + var conv2dDataFormat = 'NHWC'; + var res = conv2d(input4D, filter4D, strides, pad, conv2dDataFormat, dilations, dimRoundingMode); + if (reshapedTo3D) { + return res.as2D(res.shape[2], res.shape[3]); + } + return res.as3D(res.shape[0], res.shape[2], res.shape[3]); + } + /** + * Computes a 2D convolution over the input x. + * + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, inDepth, outDepth]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function conv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + if (dilations === void 0) { dilations = [1, 1]; } + var $x = convertToTensor(x, 'x', 'conv2d'); + var $filter = convertToTensor(filter, 'filter', 'conv2d'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + assert(x4D.rank === 4, function () { return "Error in conv2d: input must be rank 4, but got rank " + x4D.rank + "."; }); + assert($filter.rank === 4, function () { return "Error in conv2d: filter must be rank 4, but got rank " + + ($filter.rank + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in conv2d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; + assert(inDepth === $filter.shape[2], function () { return "Error in conv2d: depth of input (" + inDepth + ") must match " + + ("input depth for filter " + $filter.shape[2] + "."); }); + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in conv2D: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + var $dataFormat = convertConv2DDataFormat(dataFormat); + var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat); + var grad = function (dy, saved) { + var _a = saved, $filter = _a[0], x4D = _a[1]; + assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of conv2D: dilation rates greater than 1 ' + + ("are not yet supported in gradients. Got dilations '" + dilations + "'"); }); + return { + x: function () { return conv2dDerInput(x4D.shape, dy, $filter, strides, pad, dataFormat); }, + filter: function () { + return conv2dDerFilter(x4D, dy, $filter.shape, strides, pad, dataFormat); + } + }; + }; + var inputsToSave = [$filter, x4D]; + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.conv2d(x4D, $filter, convInfo); + save([$filter, x4D]); + return res; + }, { x: x4D, filter: $filter }, grad, 'Conv2D', convInfo, inputsToSave); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * Computes the derivative of the input of a 2D convolution. + * + * @param xShape The shape of the input: [batch, height, width, inDepth]. + * If length of 3, batch of 1 is assumed. + * @param dy The derivative of the output, of rank 4 or rank 3 of shape + * `[batch, outHeight, outWidth, outDepth]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, inDepth, outDepth]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm used: + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + function conv2dDerInput_(xShape, dy, filter, strides, pad, dataFormat, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + assert(xShape.length === dy.rank, function () { return "Length of inShape " + + ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match"); }); + var xShape4D = xShape; + var dy4D = dy; + var reshapedTo4D = false; + if (dy.rank === 3) { + reshapedTo4D = true; + dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]); + xShape4D = [1, xShape[0], xShape[1], xShape[2]]; + } + assert(xShape4D.length === 4, function () { + return "Error in conv2dDerInput: inShape must be length 4, but got length " + + (xShape4D.length + "."); + }); + assert(dy4D.rank === 4, function () { return "Error in conv2dDerInput: dy must be rank 4, but got " + + ("rank " + dy4D.rank); }); + assert(filter.rank === 4, function () { return "Error in conv2dDerInput: filter must be rank 4, but got " + + ("rank " + filter.rank); }); + var inDepth = dataFormat === 'NHWC' ? xShape4D[3] : xShape4D[1]; + var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1]; + assert(inDepth === filter.shape[2], function () { return "Error in conv2dDerInput: depth of input (" + inDepth + ") must " + + ("match input depth for filter " + filter.shape[2] + "."); }); + assert(outDepth === filter.shape[3], function () { return "Error in conv2dDerInput: depth of output (" + outDepth + ") must " + + ("match output depth for filter " + filter.shape[3] + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in conv2dDerInput: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var dilations = 1; + var grad = function (ddx, saved) { + var filter = saved[0], dy4D = saved[1]; + return { + dy4D: function () { return conv2d(ddx, filter, strides, pad, dataFormat, dilations, dimRoundingMode); }, + filter: function () { return conv2dDerFilter(ddx, dy4D, filter.shape, strides, pad, dataFormat, dimRoundingMode); } + }; + }; + var $dataFormat = convertConv2DDataFormat(dataFormat); + var convInfo = computeConv2DInfo(xShape4D, filter.shape, strides, dilations, pad, dimRoundingMode, false, $dataFormat); + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.conv2dDerInput(dy4D, filter, convInfo); + save([filter, dy4D]); + return res; + }, { dy4D: dy4D, filter: filter }, grad); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * Computes the derivative of the filter of a 2D convolution. + * + * @param x The input tensor, of rank 4 or rank 3 of shape + * [batch, height, width, inChannels]. If rank 3, batch of 1 is assumed. + * @param dy The dy image, of rank 4 or rank 3, of shape + * [batch, height, width, outDepth]. If rank 3, batch of 1 is assumed. + * @param filterShape The shape of the filter, length 4, + * [filterHeight, filterWidth, inDepth, outDepth]. + * @param strides The strides of the convolution: [strideHeight, + * strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The + * rounding mode used when computing output dimensions if pad is a + * number. If none is provided, it will not round and error if the output + * is of fractional size. + */ + function conv2dDerFilter_(x, dy, filterShape, strides, pad, dataFormat, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + var x4D = x; + if (x.rank === 3) { + x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]); + } + var dy4D = dy; + if (dy4D.rank === 3) { + dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]); + } + assert(x4D.rank === 4, function () { return "Error in conv2dDerFilter: input must be rank 4, but got shape " + + (x4D.shape + "."); }); + assert(dy4D.rank === 4, function () { return "Error in conv2dDerFilter: dy must be rank 4, but got shape " + + (dy4D.shape + "."); }); + assert(filterShape.length === 4, function () { return "Error in conv2dDerFilter: filterShape must be length 4, but got " + + (filterShape + "."); }); + var inDepth = dataFormat === 'NHWC' ? x4D.shape[3] : x4D.shape[1]; + var outDepth = dataFormat === 'NHWC' ? dy4D.shape[3] : dy4D.shape[1]; + assert(inDepth === filterShape[2], function () { return "Error in conv2dDerFilter: depth of input " + inDepth + ") must " + + ("match input depth in filter (" + filterShape[2] + "."); }); + assert(outDepth === filterShape[3], function () { return "Error in conv2dDerFilter: depth of dy (" + outDepth + ") must " + + ("match output depth for filter (" + filterShape[3] + ")."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in conv2dDerFilter: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var dilations = 1; + var $dataFormat = convertConv2DDataFormat(dataFormat); + var convInfo = computeConv2DInfo(x4D.shape, filterShape, strides, dilations, pad, dimRoundingMode, false, $dataFormat); + return ENGINE.runKernelFunc(function (backend) { return backend.conv2dDerFilter(x4D, dy4D, convInfo); }, { x4D: x4D, dy4D: dy4D }); + } + /** + * Computes the transposed 2D convolution of an image, also known as a + * deconvolution. + * + * @param x The input image, of rank 4 or rank 3, of shape + * `[batch, height, width, inDepth]`. If rank 3, batch of 1 is assumed. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, outDepth, inDepth]`. + * `inDepth` must match `inDepth` in `x`. + * @param outputShape Output shape, of rank 4 or rank 3: + * `[batch, height, width, outDepth]`. If rank 3, batch of 1 is assumed. + * @param strides The strides of the original convolution: + * `[strideHeight, strideWidth]`. + * @param pad The type of padding algorithm used in the non-transpose version + * of the op. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function conv2dTranspose_(x, filter, outputShape, strides, pad, dimRoundingMode) { + var $x = convertToTensor(x, 'x', 'conv2dTranspose'); + var $filter = convertToTensor(filter, 'filter', 'conv2dTranspose'); + return conv2dDerInput_(outputShape, $x, $filter, strides, pad, 'NHWC', dimRoundingMode); + } + /** + * Depthwise 2D convolution. + * + * Given a 4D `input` array and a `filter` array of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing + * `inChannels` convolutional filters of depth 1, this op applies a + * different filter to each input channel (expanding from 1 channel to + * `channelMultiplier` channels for each), then concatenates the results + * together. The output has `inChannels * channelMultiplier` channels. + * + * See + * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d]( + * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d) + * for more details. + * + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter tensor, rank 4, of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. If strides is a single number, then `strideHeight == + * strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. Only "NHWC" is currently supported. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function depthwiseConv2d_(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode) { + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + if (dilations === void 0) { dilations = [1, 1]; } + var $x = convertToTensor(x, 'x', 'depthwiseConv2d'); + var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + assert(x4D.rank === 4, function () { return "Error in depthwiseConv2d: input must be rank 4, but got " + + ("rank " + x4D.rank + "."); }); + assert($filter.rank === 4, function () { return "Error in depthwiseConv2d: filter must be rank 4, but got rank " + + ($filter.rank + "."); }); + assert(x4D.shape[3] === $filter.shape[2], function () { return "Error in depthwiseConv2d: number of input channels " + + ("(" + x4D.shape[3] + ") must match the inChannels dimension in ") + + ("filter " + $filter.shape[2] + "."); }); + if (dilations == null) { + dilations = [1, 1]; + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { + return 'Error in depthwiseConv2d: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); + }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in depthwiseConv2d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */); + var grad = function (dy, saved) { + assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of depthwiseConv2d: dilation rates ' + + "greater than 1 are not yet supported. Got dilations " + + ("'" + dilations + "'"); }); + var x4D = saved[0], $filter = saved[1]; + return { + x: function () { return depthwiseConv2dDerInput(x4D.shape, dy, $filter, convInfo); }, + filter: function () { return depthwiseConv2dDerFilter(x4D, dy, $filter.shape, convInfo); }, + }; + }; + var inputsToSave = [x4D, $filter]; + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.depthwiseConv2D(x4D, $filter, convInfo); + save([x4D, $filter]); + return res; + }, { x: x4D, filter: $filter }, grad, 'DepthwiseConv2dNative', convInfo, inputsToSave); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * 2-D convolution with separable filters. + * + * Performs a depthwise convolution that acts separately on channels followed + * by a pointwise convolution that mixes channels. Note that this is + * separability between dimensions [1, 2] and 3, not spatial separability + * between dimensions 1 and 2. + * + * See + * [https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d]( + * https://www.tensorflow.org/api_docs/python/tf/nn/separable_conv2d) + * for more details. + * + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param depthwiseFilter The depthwise filter tensor, rank 4, of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. This is + * the filter used in the first step. + * @param pointwiseFilter The pointwise filter tensor, rank 4, of shape + * `[1, 1, inChannels * channelMultiplier, outChannels]`. This is + * the filter used in the second step. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. If strides is a single number, then `strideHeight == + * strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. Only "NHWC" is currently supported. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function separableConv2d_(x, depthwiseFilter, pointwiseFilter, strides, pad, dilation, dataFormat) { + if (dilation === void 0) { dilation = [1, 1]; } + if (dataFormat === void 0) { dataFormat = 'NHWC'; } + var $x = convertToTensor(x, 'x', 'separableConv2d'); + var $depthwiseFilter = convertToTensor(depthwiseFilter, 'depthwiseFilter', 'separableConv2d'); + var $pointwiseFilter = convertToTensor(pointwiseFilter, 'pointwiseFilter', 'separableConv2d'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + if (dataFormat === 'NCHW') { + throw new Error('separableConv2d currently does not support dataFormat NCHW; only ' + + 'NHWC is supported'); + } + assert(x4D.rank === 4, function () { return "Error in separableConv2d: input must be rank 4, but got " + + ("rank " + x4D.rank + "."); }); + assert($depthwiseFilter.rank === 4, function () { return "Error in separableConv2d: depthwise filter must be rank 4, but " + + ("got rank " + $depthwiseFilter.rank + "."); }); + assert($pointwiseFilter.rank === 4, function () { return "Error in separableConv2d: pointwise filter must be rank 4, but " + + ("got rank " + $depthwiseFilter.rank + "."); }); + assert($pointwiseFilter.shape[0] === 1, function () { + return "Error in separableConv2d: the first dimension of pointwise filter " + + (" must be 1, but got " + $pointwiseFilter.shape[0] + "."); + }); + assert($pointwiseFilter.shape[1] === 1, function () { return "Error in separableConv2d: the second dimension of pointwise " + + ("filter must be 1, but got " + $pointwiseFilter.shape[1] + "."); }); + var inChannels = $depthwiseFilter.shape[2]; + var channelMultiplier = $depthwiseFilter.shape[3]; + assert($pointwiseFilter.shape[2] === inChannels * channelMultiplier, function () { + return "Error in separableConv2d: the third dimension of pointwise filter " + + ("must be " + inChannels * channelMultiplier + ", ") + + ("but got " + $pointwiseFilter.shape[2] + "."); + }); + var depthwise = depthwiseConv2d(x4D, $depthwiseFilter, strides, pad, dataFormat, dilation); + var pointwiseStride = 1; + var res = conv2d(depthwise, $pointwiseFilter, pointwiseStride, 'valid', dataFormat); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + function parseTupleParam$1(param) { + if (typeof param === 'number') { + return [param, param, param]; + } + if (param.length === 2) { + return [param[0], param[1], 1]; + } + return param; + } + function tupleValuesAreOne$1(param) { + var _a = parseTupleParam$1(param), dimA = _a[0], dimB = _a[1], dimC = _a[2]; + return dimA === 1 && dimB === 1 && dimC === 1; + } + function eitherStridesOrDilationsAreOne$1(strides, dilations) { + return tupleValuesAreOne$1(strides) || tupleValuesAreOne$1(dilations); + } + function depthwiseConv2dDerInput_(xShape, dy, filter, convInfo) { + var dy4D = dy; + var reshapedTo4D = false; + if (dy.rank === 3) { + reshapedTo4D = true; + dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]); + } + var res = ENGINE.runKernelFunc(function (backend) { return backend.depthwiseConv2DDerInput(dy4D, filter, convInfo); }, { dy4D: dy4D }); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + function depthwiseConv2dDerFilter_(x, dy, filterShape, convInfo) { + var x4D = x; + if (x.rank === 3) { + x4D = x.as4D(1, x.shape[0], x.shape[1], x.shape[2]); + } + var dy4D = dy; + if (dy4D.rank === 3) { + dy4D = dy.as4D(1, dy.shape[0], dy.shape[1], dy.shape[2]); + } + return ENGINE.runKernelFunc(function (backend) { return backend.depthwiseConv2DDerFilter(x4D, dy4D, convInfo); }, { x4D: x4D, dy4D: dy4D }); + } + /** + * Computes a 3D convolution over the input x. + * + * @param x The input tensor, of rank 5 or rank 4, of shape + * `[batch, depth, height, width, channels]`. If rank 4, + * batch of 1 is assumed. + * @param filter The filter, rank 5, of shape + * `[filterDepth, filterHeight, filterWidth, inChannels, outChannels]`. + * inChannels must match between input and filter. + * @param strides The strides of the convolution: `[strideDepth, strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat: An optional string from: "NDHWC", "NCDHW". Defaults to + * "NDHWC". Specify the data format of the input and output data. With the + * default format "NDHWC", the data is stored in the order of: [batch, + * depth, height, width, channels]. Only "NDHWC" is currently supported. + * @param dilations The dilation rates: `[dilationDepth, dilationHeight, + * dilationWidth]` in which we sample input values across the height + * and width dimensions in atrous convolution. Defaults to `[1, 1, 1]`. + * If `dilations` is a single number, then + * `dilationDepth == dilationHeight == dilationWidth`. If it is greater + * than 1, then all values of `strides` must be 1. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function conv3d_(x, filter, strides, pad, dataFormat, dilations) { + if (dataFormat === void 0) { dataFormat = 'NDHWC'; } + if (dilations === void 0) { dilations = [1, 1, 1]; } + var $x = convertToTensor(x, 'x', 'conv3d'); + var $filter = convertToTensor(filter, 'filter', 'conv3d'); + var x5D = $x; + var reshapedTo5D = false; + if ($x.rank === 4) { + reshapedTo5D = true; + x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]); + } + assert(x5D.rank === 5, function () { return "Error in conv3d: input must be rank 5, but got rank " + x5D.rank + "."; }); + assert($filter.rank === 5, function () { return "Error in conv3d: filter must be rank 5, but got rank " + + ($filter.rank + "."); }); + assert(x5D.shape[4] === $filter.shape[3], function () { return "Error in conv3d: depth of input (" + x5D.shape[4] + ") must match " + + ("input depth for filter " + $filter.shape[3] + "."); }); + assert(eitherStridesOrDilationsAreOne$1(strides, dilations), function () { return 'Error in conv3D: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + assert(dataFormat === 'NDHWC', function () { return "Error in conv3d: got dataFormat of " + dataFormat + " but only NDHWC is currently supported."; }); + var convInfo = computeConv3DInfo(x5D.shape, $filter.shape, strides, dilations, pad); + var grad = function (dy, saved) { + assert(tupleValuesAreOne$1(dilations), function () { + return 'Error in gradient of conv3D: dilation rates greater than 1 are ' + + ("not yet supported in gradients. Got dilations '" + dilations + "'"); + }); + var x5D = saved[0], $filter = saved[1]; + return { + x: function () { return conv3dDerInput_(x5D.shape, dy, $filter, strides, pad); }, + $filter: function () { return conv3dDerFilter_(x5D, dy, $filter.shape, strides, pad); } + }; + }; + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.conv3d(x5D, $filter, convInfo); + save([x5D, $filter]); + return res; + }, { x: x5D, $filter: $filter }, grad); + if (reshapedTo5D) { + return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]); + } + return res; + } + /** + * Computes the derivative of the input of a 3D convolution. + * + * @param xShape The shape of the input: [batch, depth, height, width, + * in_channels]. If length of 4, batch of 1 is assumed. + * @param dy The derivative of the output, of rank 5 or rank 4 of shape + * `[batch, outDepth, outHeight, outWidth, in_channels]`. + * If rank 4, batch of 1 is assumed. + * @param filter The filter, rank 5, of shape + * `[filterDepth, filterHeight, filterWidth, inDepth, outDepth]`. + * @param strides The strides of the convolution: `[strideDepth, strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm used: + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + */ + function conv3dDerInput_(xShape, dy, filter, strides, pad) { + assert(xShape.length === dy.rank, function () { return "Length of inShape " + + ("(" + xShape.length + ") and rank of dy (" + dy.rank + ") must match"); }); + var xShape5D = xShape; + var dy5D = dy; + var reshapedTo5D = false; + if (dy.rank === 4) { + reshapedTo5D = true; + dy5D = dy.as5D(1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]); + xShape5D = [1, xShape[0], xShape[1], xShape[2], xShape[3]]; + } + var inDepth = xShape5D[4]; + var outDepth = dy5D.shape[4]; + assert(xShape5D.length === 5, function () { + return "Error in conv3dDerInput: inShape must be length 5, but got length " + + (xShape5D.length + "."); + }); + assert(dy5D.rank === 5, function () { return "Error in conv3dDerInput: dy must be rank 5, but got " + + ("rank " + dy5D.rank); }); + assert(filter.rank === 5, function () { return "Error in conv3dDerInput: filter must be rank 5, but got " + + ("rank " + filter.rank); }); + assert(inDepth === filter.shape[3], function () { return "Error in conv3dDerInput: depth of input (" + inDepth + ") must " + + ("match input depth for filter " + filter.shape[3] + "."); }); + assert(outDepth === filter.shape[4], function () { return "Error in conv3dDerInput: depth of output (" + outDepth + ") must " + + ("match output depth for filter " + filter.shape[4] + "."); }); + var dilations = 1; + var convInfo = computeConv3DInfo(xShape5D, filter.shape, strides, dilations, pad); + var res = ENGINE.runKernelFunc(function (backend) { return backend.conv3dDerInput(dy5D, filter, convInfo); }, { dy5D: dy5D }); + if (reshapedTo5D) { + return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]); + } + return res; + } + /** + * Computes the derivative of the filter of a 3D convolution. + * + * @param x The input tensor, of rank 5 or rank 4 of shape + * [batch, depth, height, width, inChannels]. If rank 4, batch of 1 is + * assumed. + * @param dy The dy image, of rank 5 or rank 4, of shape + * [batch, depth, height, width, outDepth]. If rank 4, batch of 1 is + * assumed. + * @param filterShape The shape of the filter, length 5, + * [filterDepth, filterHeight, filterWidth, inDepth, outDepth]. + * @param strides The strides of the convolution: [strideDepth, strideHeight, + * strideWidth]. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + */ + function conv3dDerFilter_(x, dy, filterShape, strides, pad) { + var x5D = x; + if (x.rank === 4) { + x5D = x.as5D(1, x.shape[0], x.shape[1], x.shape[2], x.shape[3]); + } + var dy5D = dy; + if (dy5D.rank === 4) { + dy5D = dy.as5D(1, dy.shape[0], dy.shape[1], dy.shape[2], dy.shape[3]); + } + assert(x5D.rank === 5, function () { return "Error in conv3dDerFilter: input must be rank 5, but got shape " + + (x5D.shape + "."); }); + assert(dy5D.rank === 5, function () { return "Error in conv3dDerFilter: dy must be rank 5, but got shape " + + (dy5D.shape + "."); }); + assert(filterShape.length === 5, function () { return "Error in conv3dDerFilter: filterShape must be length 5, but got " + + (filterShape + "."); }); + assert(x5D.shape[4] === filterShape[3], function () { return "Error in conv3dDerFilter: depth of input " + x5D.shape[4] + ") must " + + ("match input depth in filter (" + filterShape[3] + "."); }); + assert(dy5D.shape[4] === filterShape[4], function () { return "Error in conv3dDerFilter: depth of dy (" + dy5D.shape[4] + ") must " + + ("match output depth for filter (" + filterShape[4] + ")."); }); + var dilations = 1; + var convInfo = computeConv3DInfo(x5D.shape, filterShape, strides, dilations, pad); + return ENGINE.runKernelFunc(function (backend) { return backend.conv3dDerFilter(x5D, dy5D, convInfo); }, { x5D: x5D, dy5D: dy5D }); + } + /** + * Computes the transposed 3D convolution of a volume, also known as a + * deconvolution. + * + * @param x The input image, of rank 5 or rank 4, of shape + * `[batch, depth, height, width, inDepth]`. If rank 4, batch of 1 is assumed. + * @param filter The filter, rank 4, of shape + * `[depth, filterHeight, filterWidth, outDepth, inDepth]`. + * `inDepth` must match `inDepth` in `x`. + * @param outputShape Output shape, of rank 5 or rank 4: + * `[batch, depth, height, width, outDepth]`. If rank 3, batch of 1 is + * assumed. + * @param strides The strides of the original convolution: + * `[strideDepth, strideHeight, strideWidth]`. + * @param pad The type of padding algorithm used in the non-transpose version + * of the op. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function conv3dTranspose_(x, filter, outputShape, strides, pad) { + var $x = convertToTensor(x, 'x', 'conv3dTranspose'); + var $filter = convertToTensor(filter, 'filter', 'conv3dTranspose'); + return conv3dDerInput_(outputShape, $x, $filter, strides, pad); + } + var conv1d = op({ conv1d_: conv1d_ }); + var conv2d = op({ conv2d_: conv2d_ }); + var conv3d = op({ conv3d_: conv3d_ }); + var conv2dDerFilter = op({ conv2dDerFilter_: conv2dDerFilter_ }); + var conv2dDerInput = op({ conv2dDerInput_: conv2dDerInput_ }); + var depthwiseConv2d = op({ depthwiseConv2d_: depthwiseConv2d_ }); + var depthwiseConv2dDerInput = op({ depthwiseConv2dDerInput_: depthwiseConv2dDerInput_ }); + var depthwiseConv2dDerFilter = op({ depthwiseConv2dDerFilter_: depthwiseConv2dDerFilter_ }); + var separableConv2d = op({ separableConv2d_: separableConv2d_ }); + var conv2dTranspose = op({ conv2dTranspose_: conv2dTranspose_ }); + var conv3dTranspose = op({ conv3dTranspose_: conv3dTranspose_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the dot product of two matrices, A * B. These must be matrices. + * + * ```js + * const a = tf.tensor2d([1, 2], [1, 2]); + * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * a.matMul(b).print(); // or tf.matMul(a, b) + * ``` + * @param a First matrix in dot product operation. + * @param b Second matrix in dot product operation. + * @param transposeA If true, `a` is transposed before multiplication. + * @param transposeB If true, `b` is transposed before multiplication. + */ + /** @doc {heading: 'Operations', subheading: 'Matrices'} */ + function matMul_(a, b, transposeA, transposeB) { + var _a; + if (transposeA === void 0) { transposeA = false; } + if (transposeB === void 0) { transposeB = false; } + var $a = convertToTensor(a, 'a', 'matMul'); + var $b = convertToTensor(b, 'b', 'matMul'); + _a = makeTypesMatch($a, $b), $a = _a[0], $b = _a[1]; + var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; + var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; + var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; + var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; + var outerDimsA = $a.shape.slice(0, -2); + var outerDimsB = $b.shape.slice(0, -2); + var batchDimA = sizeFromShape(outerDimsA); + var batchDimB = sizeFromShape(outerDimsB); + assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, function () { return "Error in matMul: inputs must have the same rank of at least 2, " + + ("got ranks " + $a.rank + " and " + $b.rank + "."); }); + assert(arraysEqual(outerDimsA, outerDimsB), function () { return "Error in matMul: outer dimensions (" + outerDimsA + ") and (" + + (outerDimsB + ") of Tensors with shapes " + $a.shape + " and ") + + ($b.shape + " must match."); }); + assert(innerShapeA === innerShapeB, function () { return "Error in matMul: inner shapes (" + innerShapeA + ") and (" + + (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") + + ($b.shape + " and transposeA=" + transposeA) + + (" and transposeB=" + transposeB + " must match."); }); + var outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]); + var a3D = transposeA ? $a.as3D(batchDimA, innerShapeA, outerShapeA) : + $a.as3D(batchDimA, outerShapeA, innerShapeA); + var b3D = transposeB ? $b.as3D(batchDimB, outerShapeB, innerShapeB) : + $b.as3D(batchDimB, innerShapeB, outerShapeB); + var grad = function (dy, saved) { + var _a = saved, a3D = _a[0], b3D = _a[1]; + if (!transposeA && !transposeB) { + return { + a: function () { return dy.matMul(b3D, false, true); }, + b: function () { return a3D.matMul(dy, true, false); } + }; + } + else if (!transposeA && transposeB) { + return { + a: function () { return dy.matMul(b3D, false, false); }, + b: function () { return dy.matMul(a3D, true, false); } + }; + } + else if (transposeA && !transposeB) { + return { + a: function () { return b3D.matMul(dy, false, true); }, + b: function () { return a3D.matMul(dy, false, false); } + }; + } + else { + return { + a: function () { return b3D.matMul(dy, true, true); }, + b: function () { return dy.matMul(a3D, true, true); } + }; + } + }; + var attrs = { transposeA: transposeA, transposeB: transposeB }; + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.batchMatMul(a3D, b3D, transposeA, transposeB); + save([a3D, b3D]); + return res; + }, { a: a3D, b: b3D }, grad, 'BatchMatMul', attrs); + return res.reshape(outShape); + } + /** + * Computes the outer product of two vectors, `v1` and `v2`. + * + * ```js + * const a = tf.tensor1d([1, 2, 3]); + * const b = tf.tensor1d([3, 4, 5]); + * + * tf.outerProduct(a, b).print(); + * ``` + * @param v1 The first vector in the outer product operation. + * @param v2 The second vector in the outer product operation. + */ + /** @doc {heading: 'Operations', subheading: 'Matrices'} */ + function outerProduct_(v1, v2) { + var $v1 = convertToTensor(v1, 'v1', 'outerProduct'); + var $v2 = convertToTensor(v2, 'v2', 'outerProduct'); + assert($v1.rank === 1 && $v2.rank === 1, function () { return "Error in outerProduct: inputs must be rank 1, but got ranks " + + ($v1.rank + " and " + $v2.rank + "."); }); + return $v1.as2D(-1, 1).matMul($v2.as2D(1, -1)); + } + /** + * Computes the dot product of two matrices and/or vectors, `t1` and `t2`. + * + * ```js + * const a = tf.tensor1d([1, 2]); + * const b = tf.tensor2d([[1, 2], [3, 4]]); + * const c = tf.tensor2d([[1, 2, 3], [4, 5, 6]]); + * + * a.dot(b).print(); // or tf.dot(a, b) + * b.dot(a).print(); + * b.dot(c).print(); + * ``` + * @param t1 The first tensor in the dot operation. + * @param t2 The second tensor in the dot operation. + */ + /** @doc {heading: 'Operations', subheading: 'Matrices'} */ + function dot_(t1, t2) { + var $t1 = convertToTensor(t1, 't1', 'dot'); + var $t2 = convertToTensor(t2, 't2', 'dot'); + assert(($t1.rank === 1 || $t1.rank === 2) && ($t2.rank === 1 || $t2.rank === 2), function () { return "Error in dot: inputs must all be rank 1 or 2, but got ranks " + + ($t1.rank + " and " + $t2.rank + "."); }); + var t1Inner = ($t1.rank === 1 ? $t1.size : $t1.shape[1]); + var t2Inner = ($t2.rank === 1 ? $t2.size : $t2.shape[0]); + assert(t1Inner === t2Inner, function () { return "Error in dot: inner dimensions of inputs must match, but got " + + (t1Inner + " and " + t2Inner + "."); }); + if ($t1.rank === 1 && $t2.rank === 1) { + return $t1.as2D(1, -1).matMul($t2.as2D(-1, 1)).asScalar(); + } + else if ($t1.rank === 1 && $t2.rank === 2) { + return $t1.as2D(1, -1).matMul($t2.as2D($t2.shape[0], $t2.shape[1])).as1D(); + } + else if ($t1.rank === 2 && $t2.rank === 1) { + return $t1.matMul($t2.as2D(-1, 1)).as1D(); + } + else { + return $t1.matMul($t2.as2D($t2.shape[0], $t2.shape[1])); + } + } + var matMul = op({ matMul_: matMul_ }); + var dot = op({ dot_: dot_ }); + var outerProduct = op({ outerProduct_: outerProduct_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Reverses a `tf.Tensor1D`. + * + * @param x The input tensor. + */ + function reverse1d_(x) { + var $x = convertToTensor(x, 'x', 'reverse'); + assert($x.rank === 1, function () { return "Error in reverse1D: x must be rank 1 but got rank " + $x.rank + "."; }); + return reverse($x, 0); + } + /** + * Reverses a `tf.Tensor2D` along a specified axis. + * + * @param x The input tensor. + * @param axis The set of dimensions to reverse. Must be in the + * range [-rank(x), rank(x)). Defaults to all axes. + */ + function reverse2d_(x, axis) { + var $x = convertToTensor(x, 'x', 'reverse'); + assert($x.rank === 2, function () { return "Error in reverse2D: x must be rank 2 but got rank " + $x.rank + "."; }); + return reverse($x, axis); + } + /** + * Reverses a `tf.Tensor3D` along a specified axis. + * + * @param x The input tensor. + * @param axis The set of dimensions to reverse. Must be in the + * range [-rank(x), rank(x)). Defaults to all axes. + */ + function reverse3d_(x, axis) { + var $x = convertToTensor(x, 'x', 'reverse'); + assert($x.rank === 3, function () { return "Error in reverse3D: x must be rank 3 but got rank " + $x.rank + "."; }); + return reverse($x, axis); + } + /** + * Reverses a `tf.Tensor4D` along a specified axis. + * + * @param x The input tensor. + * @param axis The set of dimensions to reverse. Must be in the + * range [-rank(x), rank(x)). Defaults to all axes. + */ + function reverse4d_(x, axis) { + var $x = convertToTensor(x, 'x', 'reverse'); + assert($x.rank === 4, function () { return "Error in reverse4D: x must be rank 4 but got rank " + $x.rank + "."; }); + return reverse($x, axis); + } + /** + * Reverses a `tf.Tensor` along a specified axis. + * + * Also available are stricter rank-specific methods that assert that `x` is + * of the given rank: + * - `tf.reverse1d` + * - `tf.reverse2d` + * - `tf.reverse3d` + * - `tf.reverse4d` + * + * Except `tf.reverse1d` (which does not have axis param), all methods have + * same signature as this method. + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * + * x.reverse().print(); + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.reverse(axis).print(); + * ``` + * @param x The input tensor to be reversed. + * @param axis The set of dimensions to reverse. Must be in the + * range [-rank(x), rank(x)). Defaults to all axes. + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function reverse_(x, axis) { + var $x = convertToTensor(x, 'x', 'reverse'); + if ($x.rank === 0) { + return $x.clone(); + } + var axes = parseAxisParam(axis, $x.shape); + var grad = function (dy) { + return { $x: function () { return dy.reverse(axes); } }; + }; + var res = ENGINE.runKernelFunc(function (backend) { return backend.reverse($x, axes); }, { $x: $x }, grad); + return res.reshapeAs($x); + } + var reverse = op({ reverse_: reverse_ }); + var reverse1d = op({ reverse1d_: reverse1d_ }); + var reverse2d = op({ reverse2d_: reverse2d_ }); + var reverse3d = op({ reverse3d_: reverse3d_ }); + var reverse4d = op({ reverse4d_: reverse4d_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the 2D max pooling of an image. + * + * @param x The input tensor, of rank 4 or rank 3 of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param filterSize The filter size: `[filterHeight, filterWidth]`. If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + function maxPoolImpl_(x, filterSize, strides, dilations, pad, dimRoundingMode) { + var $x = convertToTensor(x, 'x', 'maxPool'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + if (dilations == null) { + dilations = [1, 1]; + } + assert(x4D.rank === 4, function () { return "Error in maxPool: input must be rank 4 but got rank " + x4D.rank + "."; }); + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in maxPool: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in maxPool: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computePool2DInfo(x4D.shape, filterSize, strides, dilations, pad, dimRoundingMode); + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + arraysEqual(convInfo.inShape, convInfo.outShape) && + convInfo.padInfo.type === 'VALID') { + return $x.clone(); + } + var grad = function (dy, saved) { + var x4D = saved[0], y = saved[1]; + return { + x: function () { return maxPoolBackprop(dy, x4D, y, filterSize, strides, dilations, pad); } + }; + }; + var inputsToSave = [x4D]; + var res = ENGINE.runKernelFunc(function (backend, save) { + var y = backend.maxPool(x4D, convInfo); + save([x4D, y]); + return y; + }, { x: x4D }, grad, 'MaxPool', convInfo, inputsToSave); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * Computes the 2D max pooling of an image. + * + * @param x The input tensor, of rank 4 or rank 3 of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param filterSize The filter size: `[filterHeight, filterWidth]`. If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function maxPool_(x, filterSize, strides, pad, dimRoundingMode) { + return maxPoolImpl_(x, filterSize, strides, 1, pad, dimRoundingMode); + } + /** + * Computes the 2D average pooling of an image. + * + * @param x The input tensor, of rank 4 or rank 3 of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param filterSize The filter size: `[filterHeight, filterWidth]`. If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in dilated pooling. Defaults to `[1, 1]`. If `dilations` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param pad The type of padding algorithm: + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + function avgPoolImpl_(x, filterSize, strides, dilations, pad, dimRoundingMode) { + var $x = convertToTensor(x, 'x', 'avgPool', 'float32'); + if (dilations == null) { + dilations = [1, 1]; + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in avgPool: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + assert(x4D.rank === 4, function () { return "Error in avgPool: x must be rank 4 but got rank " + x4D.rank + "."; }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in avgPool: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computePool2DInfo(x4D.shape, filterSize, strides, dilations, pad, dimRoundingMode); + if (convInfo.filterWidth === 1 && convInfo.filterHeight === 1 && + arraysEqual(convInfo.inShape, convInfo.outShape) && + convInfo.padInfo.type === 'VALID') { + return $x.clone(); + } + var grad = function (dy) { + return { + x: function () { return avgPoolBackprop(dy, x4D, filterSize, strides, dilations, pad); } + }; + }; + var res = ENGINE.runKernelFunc(function (backend) { return backend.avgPool(x4D, convInfo); }, { x: x4D }, grad, 'AvgPool', convInfo); + res = res.cast($x.dtype); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * Computes the 2D average pooling of an image. + * + * @param x The input tensor, of rank 4 or rank 3 of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param filterSize The filter size: `[filterHeight, filterWidth]`. If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param pad The type of padding algorithm: + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function avgPool_(x, filterSize, strides, pad, dimRoundingMode) { + return avgPoolImpl_(x, filterSize, strides, 1, pad, dimRoundingMode); + } + /** + * Performs an N-D pooling operation + * + * @param input The input tensor, of rank 4 or rank 3 of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param windowShape The filter size: `[filterHeight, filterWidth]`. If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param poolingType The type of pooling, either 'max' or 'avg'. + * @param pad The type of padding algorithm: + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in dilated pooling. Defaults to `[1, 1]`. If `dilationRate` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function pool_(input, windowShape, poolingType, pad, dilations, strides) { + if (dilations == null) { + dilations = [1, 1]; + } + if (strides == null) { + strides = 1; + } + if (pad === 0) { + pad = 'valid'; + } + var $x = convertToTensor(input, 'x', 'maxPool'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in pool: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + var convInfo = computePool2DInfo(x4D.shape, windowShape, strides, dilations, pad); + var dilation = [convInfo.dilationHeight, convInfo.dilationWidth]; + // The following implementation does batchToSpace(pool(spaceToBatch(x))) + // whenever dilation > 1 since the TF kernels do not support dilation > 1. + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L1037 + var basePadding; + if (pad === 'same') { + basePadding = withSpaceToBatchBasePaddings([convInfo.filterHeight, convInfo.filterWidth], dilation); + } + else { + basePadding = [[0, 0], [0, 0]]; + } + var isDilationOne = dilation[0] === 1 && dilation[1] === 1; + var _a = requiredSpaceToBatchPaddings([convInfo.inHeight, convInfo.inWidth], dilation, basePadding), adjustedPadding = _a[0], adjustedCrops = _a[1]; + var convertedPad = isDilationOne ? pad : 'valid'; + var convertedX = isDilationOne ? x4D : spaceToBatchND(x4D, dilation, adjustedPadding); + var forwardOp = poolingType === 'avg' ? + function () { return avgPoolImpl_(convertedX, windowShape, strides, 1 /* dilation */, convertedPad); } : + function () { return maxPoolImpl_(convertedX, windowShape, strides, 1 /* dilation */, convertedPad); }; + var y = forwardOp(); + var res = isDilationOne ? y : batchToSpaceND(y, dilation, adjustedCrops); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * Computes the backprop of a 2D max pool. + * + * @param dy The dy error, of rank 4 or rank 3 of shape + * [batchSize, height, width, channels]. If rank 3, batch of 1 is + * assumed. + * @param input The original input image, of rank 4, of shape + * [batchSize, height, width, channels]. + * @param output The original output image, of rank 4, of shape + * [batchSize, outHeight, outWidth, channels]. + * @param filterSize The filter size: `[filterHeight, filterWidth]`. If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The + * rounding mode used when computing output dimensions if pad is a + * number. If none is provided, it will not round and error if the output + * is of fractional size. + */ + function maxPoolBackprop(dy, input, output, filterSize, strides, dilations, pad, dimRoundingMode) { + var $dy = convertToTensor(dy, 'dy', 'maxPoolBackprop'); + var $input = convertToTensor(input, 'input', 'maxPoolBackprop'); + var $output = convertToTensor(output, 'output', 'maxPoolBackprop'); + assert($input.rank === $dy.rank, function () { return "Rank of input (" + $input.rank + ") does not match rank of dy " + + ("(" + $dy.rank + ")"); }); + if (dilations == null) { + dilations = [1, 1]; + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { + return 'Error in maxPoolBackProp: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); + }); + assert($dy.rank === 4, function () { return "Error in maxPoolBackprop: dy must be rank 4 but got rank " + + ($dy.rank + "."); }); + assert($input.rank === 4, function () { return "Error in maxPoolBackprop: input must be rank 4 but got rank " + + ($input.rank + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in maxPoolBackprop: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computePool2DInfo($input.shape, filterSize, strides, dilations, pad, dimRoundingMode); + var res = ENGINE.runKernelFunc(function (backend) { return backend.maxPoolBackprop($dy, $input, $output, convInfo); }, { $dy: $dy, $input: $input }); + return res; + } + /** + * Computes the backprop of an 2D avg pool. + * + * @param dy The dy error, of rank 4 or rank 3 of shape + * [batchSize, height, width, channels]. If rank 3, batch of 1 is + * assumed. + * @param input The input image, of rank 4 or rank 3 of shape + * [batchSize, height, width, channels]. If rank 3, batch of 1 is + * assumed. + * @param filterSize The filter size: `[filterHeight, filterWidth]`. If + * `filterSize` is a single number, then `filterHeight == filterWidth`. + * @param strides The strides of the pooling: `[strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + */ + function avgPoolBackprop(dy, input, filterSize, strides, dilations, pad) { + var $dy = convertToTensor(dy, 'dy', 'avgPoolBackprop'); + var $input = convertToTensor(input, 'input', 'avgPoolBackprop'); + assert($input.rank === $dy.rank, function () { return "Rank of input (" + $input.rank + ") does not match rank of dy (" + $dy.rank + ")"; }); + if (dilations == null) { + dilations = [1, 1]; + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { + return 'Error in avgPoolBackprop: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); + }); + var input4D = $input; + var dy4D = $dy; + var reshapedTo4D = false; + if ($input.rank === 3) { + reshapedTo4D = true; + input4D = $input.as4D(1, $input.shape[0], $input.shape[1], $input.shape[2]); + dy4D = $dy.as4D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2]); + } + assert(dy4D.rank === 4, function () { return "Error in avgPoolBackprop: dy must be rank 4 but got rank " + + (dy4D.rank + "."); }); + assert(input4D.rank === 4, function () { return "Error in avgPoolBackprop: input must be rank 4 but got rank " + + (input4D.rank + "."); }); + var convInfo = computePool2DInfo(input4D.shape, filterSize, strides, dilations, pad); + var res = ENGINE.runKernelFunc(function (backend) { return backend.avgPoolBackprop(dy4D, input4D, convInfo); }, { dy4D: dy4D, input4D: input4D }); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + // Helper function to compute crops and paddings for pool with dilation > 1. + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/array_ops.py#L2184 + function requiredSpaceToBatchPaddings(inputShape, blockShape, basePadding) { + var padStart = basePadding.map(function (b) { return b[0]; }); + var origPadEnd = basePadding.map(function (b) { return b[1]; }); + var fullInputShape = inputShape.concat(padStart, origPadEnd); + var padEndExtra = blockShape.map(function (b, i) { return (b - fullInputShape[i] % b) % b; }); + var padEnd = origPadEnd.map(function (s, i) { return s + padEndExtra[i]; }); + var paddings = blockShape.map(function (_, i) { return [padStart[i], padEnd[i]]; }); + var crops = blockShape.map(function (_, i) { return [0, padEndExtra[i]]; }); + return [paddings, crops]; + } + // Helper function to compute base paddings for pool with dilation > 1. + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/50f6bb67dc98c9b74630b6047aae7a4f8a40fd02/tensorflow/python/ops/nn_ops.py#L524 + function withSpaceToBatchBasePaddings(filterShape, dilation) { + // Spatial dimensions of the filters and the upsampled filters in which we + // introduce (rate - 1) zeros between consecutive filter values. + var dilatedFilterShape = filterShape.map(function (s, i) { + return s + (s - 1) * (dilation[i] - 1); + }); + var padExtraShape = dilatedFilterShape.map(function (s) { return s - 1; }); + // When padding is odd, we pad more at end, following the same + // convention as conv2d. + var padExtraStart = padExtraShape.map(function (s) { return Math.floor(s / 2); }); + var padExtraEnd = padExtraShape.map(function (s, i) { return s - padExtraStart[i]; }); + return padExtraShape.map(function (_, i) { + return [padExtraStart[i], padExtraEnd[i]]; + }); + } + /** + * Computes the 3D average pooling. + * + * ```js + * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]); + * const result = tf.avgPool3d(x, 2, 1, 'valid'); + * result.print(); + * ``` + * + * @param x The input tensor, of rank 5 or rank 4 of shape + * `[batch, depth, height, width, inChannels]`. + * @param filterSize The filter size: + * `[filterDepth, filterHeight, filterWidth]`. + * If `filterSize` is a single number, + * then `filterDepth == filterHeight == filterWidth`. + * @param strides The strides of the pooling: + * `[strideDepth, strideHeight, strideWidth]`. + * If `strides` is a single number, + * then `strideDepth == strideHeight == strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1*1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to + * "NDHWC". Specify the data format of the input and output data. With the + * default format "NDHWC", the data is stored in the order of: [batch, + * depth, height, width, channels]. Only "NDHWC" is currently supported. + * @param dilations The dilation rates: + * `[dilationDepth, dilationHeight, dilationWidth]` + * in which we sample input values across the depth, height and width + * dimensions in dilated pooling. + * Defaults to `[1, 1, 1]`. If `dilations` is a single number, + * then `dilationDepth == dilationHeight == dilationWidth`. + * If it is greater than 1, then all values of `strides` must be 1. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function avgPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat, dilations) { + if (dataFormat === void 0) { dataFormat = 'NDHWC'; } + var $x = convertToTensor(x, 'x', 'avgPool3d', 'float32'); + var x5D = $x; + var reshapedTo5D = false; + if ($x.rank === 4) { + reshapedTo5D = true; + x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]); + } + if (dilations == null) { + dilations = [1, 1, 1]; + } + assert(x5D.rank === 5, function () { return "Error in avgPool3d: x must be rank 5 but got rank " + x5D.rank + "."; }); + assert(dataFormat === 'NDHWC', function () { return "Error in avgPool3d: Only NDHWC is currently supported, " + + ("but got dataFormat of " + dataFormat); }); + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in avgPool3d: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in avgPool3d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computePool3DInfo(x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); + var grad = function (dy) { + return { + x: function () { return avgPool3dBackprop(dy, x5D, filterSize, strides, dilations, pad, dimRoundingMode); } + }; + }; + var res = ENGINE.runKernelFunc(function (backend) { return backend.avgPool3d(x5D, convInfo); }, { x: x5D }, grad); + res = res.cast(x5D.dtype); + if (reshapedTo5D) { + return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]); + } + return res; + } + /** + * Computes the backprop of a 3d avg pool. + * + * @param dy The dy error, of rank 5 of shape + * [batchSize, depth, height, width, channels]. + * assumed. + * @param input The original input image, of rank 5 or rank4 of shape + * [batchSize, depth, height, width, channels]. + * @param filterSize The filter size: + * `[filterDepth, filterHeight, filterWidth]`. + * `filterSize` is a single number, + * then `filterDepth == filterHeight == filterWidth`. + * @param strides The strides of the pooling: + * `[strideDepth, strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param dilations The dilation rates: + * `[dilationDepth, dilationHeight, dilationWidth]` + * in which we sample input values across the depth, height and width + * dimensions in dilated pooling. + * Defaults to `[1, 1, 1]`. If `dilations` is a single number, + * then `dilationDepth == dilationHeight == dilationWidth`. + * If it is greater than 1, then all values of `strides` must be 1. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The + * rounding mode used when computing output dimensions if pad is a + * number. If none is provided, it will not round and error if the output + * is of fractional size. + */ + function avgPool3dBackprop(dy, input, filterSize, strides, dilations, pad, dimRoundingMode) { + var $dy = convertToTensor(dy, 'dy', 'avgPool3dBackprop'); + var $input = convertToTensor(input, 'input', 'avgPool3dBackprop'); + var dy5D = $dy; + var input5D = $input; + var reshapedTo5D = false; + if ($input.rank === 4) { + reshapedTo5D = true; + dy5D = $dy.as5D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]); + input5D = $input.as5D(1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]); + } + assert(dy5D.rank === 5, function () { return "Error in avgPool3dBackprop: dy must be rank 5 but got rank " + + (dy5D.rank + "."); }); + assert(input5D.rank === 5, function () { return "Error in avgPool3dBackprop: input must be rank 5 but got rank " + + (input5D.rank + "."); }); + if (dilations == null) { + dilations = [1, 1, 1]; + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in avgPool3dBackprop: Either strides or dilations ' + + ("must be 1. Got strides " + strides + " and dilations '" + dilations + "'"); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in maxPool3dBackprop: pad must be an integer when " + + ("using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computePool3DInfo(input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode); + var res = ENGINE.runKernelFunc(function (backend) { return backend.avgPool3dBackprop(dy5D, input5D, convInfo); }, { dy5D: dy5D, input5D: input5D }); + if (reshapedTo5D) { + return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]); + } + return res; + } + /** + * Computes the 3D max pooling. + * + * ```js + * const x = tf.tensor5d([1, 2, 3, 4, 5, 6, 7, 8], [1, 2, 2, 2, 1]); + * const result = tf.maxPool3d(x, 2, 1, 'valid'); + * result.print(); + * ``` + * + * @param x The input tensor, of rank 5 or rank 4 of shape + * `[batch, depth, height, width, inChannels]`. + * @param filterSize The filter size: + * `[filterDepth, filterHeight, filterWidth]`. + * If `filterSize` is a single number, + * then `filterDepth == filterHeight == filterWidth`. + * @param strides The strides of the pooling: + * `[strideDepth, strideHeight, strideWidth]`. + * If `strides` is a single number, + * then `strideDepth == strideHeight == strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1*1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + * @param dataFormat An optional string from: "NDHWC", "NCDHW". Defaults to + * "NDHWC". Specify the data format of the input and output data. With the + * default format "NDHWC", the data is stored in the order of: [batch, + * depth, height, width, channels]. Only "NDHWC" is currently supported. + * @param dilations The dilation rates: + * `[dilationDepth, dilationHeight, dilationWidth]` + * in which we sample input values across the depth, height and width + * dimensions in dilated pooling. + * Defaults to `[1, 1, 1]`. If `dilations` is a single number, + * then `dilationDepth == dilationHeight == dilationWidth`. + * If it is greater than 1, then all values of `strides` must be 1. + */ + /** @doc {heading: 'Operations', subheading: 'Convolution'} */ + function maxPool3d_(x, filterSize, strides, pad, dimRoundingMode, dataFormat, dilations) { + if (dataFormat === void 0) { dataFormat = 'NDHWC'; } + var $x = convertToTensor(x, 'x', 'maxPool3d'); + var x5D = $x; + var reshapedTo5D = false; + if ($x.rank === 4) { + reshapedTo5D = true; + x5D = $x.as5D(1, $x.shape[0], $x.shape[1], $x.shape[2], $x.shape[3]); + } + if (dilations == null) { + dilations = [1, 1, 1]; + } + assert(x5D.rank === 5, function () { return "Error in maxPool3d: x must be rank 5 but got rank " + x5D.rank + "."; }); + assert(dataFormat === 'NDHWC', function () { return "Error in maxPool3d: Only NDHWC is currently supported, " + + ("but got dataFormat of " + dataFormat); }); + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in maxPool3d: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in maxPool3d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computePool3DInfo(x5D.shape, filterSize, strides, dilations, pad, dimRoundingMode, dataFormat); + var grad = function (dy, saved) { + var x5D = saved[0], y = saved[1]; + return { + x: function () { return maxPool3dBackprop(dy, x5D, y, filterSize, strides, dilations, pad, dimRoundingMode); } + }; + }; + var res = ENGINE.runKernelFunc(function (backend, save) { + var y = backend.maxPool3d(x5D, convInfo); + save([x5D, y]); + return y; + }, { x: x5D }, grad); + if (reshapedTo5D) { + return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]); + } + return res; + } + /** + * Computes the backprop of a 3d max pool. + * + * @param dy The dy error, of rank 5 of shape + * [batchSize, depth, height, width, channels]. + * assumed. + * @param input The original input image, of rank 5 or rank 4 of shape + * [batchSize, depth, height, width, channels]. + * @param output The original output image, of rank 5 of shape + * [batchSize, outDepth, outHeight, outWidth, channels]. + * @param filterSize The filter size: + * `[filterDepth, filterHeight, filterWidth]`. + * `filterSize` is a single number, + * then `filterDepth == filterHeight == filterWidth`. + * @param strides The strides of the pooling: + * `[strideDepth, strideHeight, strideWidth]`. If + * `strides` is a single number, then `strideHeight == strideWidth`. + * @param dilations The dilation rates: + * `[dilationDepth, dilationHeight, dilationWidth]` + * in which we sample input values across the depth, height and width + * dimensions in dilated pooling. + * Defaults to `[1, 1, 1]`. If `dilations` is a single number, + * then `dilationDepth == dilationHeight == dilationWidth`. + * If it is greater than 1, then all values of `strides` must be 1. + * @param pad A string from: 'same', 'valid'. The type of padding algorithm + * used in the forward prop of the op. + * @param dimRoundingMode A string from: 'ceil', 'round', 'floor'. The + * rounding mode used when computing output dimensions if pad is a + * number. If none is provided, it will not round and error if the output + * is of fractional size. + */ + function maxPool3dBackprop(dy, input, output, filterSize, strides, dilations, pad, dimRoundingMode) { + var $dy = convertToTensor(dy, 'dy', 'maxPool3dBackprop'); + var $input = convertToTensor(input, 'input', 'maxPool3dBackprop'); + var $output = convertToTensor(output, 'output', 'maxPool3dBackprop'); + var dy5D = $dy; + var input5D = $input; + var output5D = $output; + var reshapedTo5D = false; + if ($input.rank === 4) { + reshapedTo5D = true; + dy5D = $dy.as5D(1, $dy.shape[0], $dy.shape[1], $dy.shape[2], $dy.shape[3]); + input5D = $input.as5D(1, $input.shape[0], $input.shape[1], $input.shape[2], $input.shape[3]); + output5D = $output.as5D(1, $output.shape[0], $output.shape[1], $output.shape[2], $output.shape[3]); + } + assert(dy5D.rank === 5, function () { return "Error in maxPool3dBackprop: dy must be rank 5 but got rank " + + (dy5D.rank + "."); }); + assert(input5D.rank === 5, function () { return "Error in maxPool3dBackprop: input must be rank 5 but got rank " + + (input5D.rank + "."); }); + assert(output5D.rank === 5, function () { return "Error in maxPool3dBackprop: output must be rank 5 but got rank " + + (output5D.rank + "."); }); + if (dilations == null) { + dilations = [1, 1, 1]; + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in maxPool3dBackprop: Either strides or dilations ' + + ("must be 1. Got strides " + strides + " and dilations '" + dilations + "'"); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in maxPool3dBackprop: pad must be an integer when " + + ("using, dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computePool3DInfo(input5D.shape, filterSize, strides, dilations, pad, dimRoundingMode); + var res = ENGINE.runKernelFunc(function (backend) { return backend.maxPool3dBackprop(dy5D, input5D, output5D, convInfo); }, { dy5D: dy5D, input5D: input5D }); + if (reshapedTo5D) { + return res.as4D(res.shape[1], res.shape[2], res.shape[3], res.shape[4]); + } + return res; + } + var maxPool = op({ maxPool_: maxPool_ }); + var avgPool = op({ avgPool_: avgPool_ }); + var pool = op({ pool_: pool_ }); + var maxPool3d = op({ maxPool3d_: maxPool3d_ }); + var avgPool3d = op({ avgPool3d_: avgPool3d_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Extracts a 1D slice from 1D array starting at coordinates `begin` and is + * of length `size`. See `slice` for details. + */ + function slice1d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice1d'); + assert($x.rank === 1, function () { + return "slice1d expects a rank-1 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, [begin], [size]); + } + /** + * Extracts a 2D slice from a 2D array starting at coordinates `begin` and + * is of size `size`. See `slice` for details. + */ + function slice2d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice2d'); + assert($x.rank === 2, function () { + return "slice2d expects a rank-2 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, begin, size); + } + /** + * Extracts a 3D slice from a 3D array starting at coordinates `begin` and + * is of size `size`. See `slice` for details. + */ + function slice3d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice3d'); + assert($x.rank === 3, function () { + return "slice3d expects a rank-3 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, begin, size); + } + /** + * Extracts a 4D slice from a 4D array starting at coordinates `begin` and + * is of size `size`. See `slice` for details. + */ + function slice4d_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice4d'); + assert($x.rank === 4, function () { + return "slice4d expects a rank-4 tensor, but got a rank-" + $x.rank + " tensor"; + }); + return slice($x, begin, size); + } + /** + * Extracts a slice from a `tf.Tensor` starting at coordinates `begin` + * and is of size `size`. + * + * Also available are stricter rank-specific methods with the same signature + * as this method that assert that `x` is of the given rank: + * - `tf.slice1d` + * - `tf.slice2d` + * - `tf.slice3d` + * - `tf.slice4d` + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * + * x.slice([1], [2]).print(); + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * x.slice([1, 0], [1, 2]).print(); + * ``` + * @param x The input `tf.Tensor` to slice from. + * @param begin The coordinates to start the slice from. The length can be + * less than the rank of x - the rest of the axes will have implicit 0 as + * start. Can also be a single number, in which case it specifies the + * first axis. + * @param size The size of the slice. The length can be less than the rank of + * x - the rest of the axes will have implicit -1. A value of -1 requests + * the rest of the dimensions in the axis. Can also be a single number, + * in which case it specifies the size of the first axis. + */ + /** @doc {heading: 'Tensors', subheading: 'Slicing and Joining'} */ + function slice_(x, begin, size) { + var $x = convertToTensor(x, 'x', 'slice'); + if ($x.rank === 0) { + throw new Error('Slicing scalar is not possible'); + } + // The following logic allows for more ergonomic calls. + var begin_; + if (typeof begin === 'number') { + begin_ = [begin].concat(new Array($x.rank - 1).fill(0)); + } + else if (begin.length < $x.rank) { + begin_ = begin.concat(new Array($x.rank - begin.length).fill(0)); + } + else { + begin_ = begin.slice(); + } + begin_.forEach(function (d) { + assert(d !== -1, function () { return 'slice() does not support negative begin indexing.'; }); + }); + var size_; + if (size == null) { + size_ = new Array($x.rank).fill(-1); + } + else if (typeof size === 'number') { + size_ = [size].concat(new Array($x.rank - 1).fill(-1)); + } + else if (size.length < $x.rank) { + size_ = size.concat(new Array($x.rank - size.length).fill(-1)); + } + else { + size_ = size; + } + size_ = size_.map(function (d, i) { + if (d >= 0) { + return d; + } + else { + assert(d === -1, function () { return "Negative size values should be exactly -1 but got " + + (d + " for the slice() size at index " + i + "."); }); + return $x.shape[i] - begin_[i]; + } + }); + assertParamsValid($x, begin_, size_); + var inputShape = $x.shape; + var grad = function (dy) { + // Create an Nx2 padding where the first column represents how many + // zeros are prepended (at start) for each dimension, and the second + // column indicates how many zeros are appended (at end). + // The number of zeros to append is the shape of the input + // elementwise-subtracted by both the begin vector and sizes vector. + var paddings = []; + for (var i = 0; i < dy.rank; i++) { + paddings.push([begin_[i], inputShape[i] - begin_[i] - size_[i]]); + } + return { x: function () { return dy.pad(paddings); } }; + }; + var attrs = { begin: begin_, size: size_ }; + return ENGINE.runKernelFunc(function (backend) { return backend.slice($x, begin_, size_); }, { x: $x }, grad, 'Slice', attrs); + } + var slice = op({ slice_: slice_ }); + var slice1d = op({ slice1d_: slice1d_ }); + var slice2d = op({ slice2d_: slice2d_ }); + var slice3d = op({ slice3d_: slice3d_ }); + var slice4d = op({ slice4d_: slice4d_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the log(sum(exp(elements across the reduction dimensions)). + * + * Reduces the input along the dimensions given in `axis`. Unless `keepDims` + * is true, the rank of the array is reduced by 1 for each entry in `axis`. + * If `keepDims` is true, the reduced dimensions are retained with length 1. + * If `axis` has no entries, all dimensions are reduced, and an array with a + * single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.logSumExp().print(); // or tf.logSumExp(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.logSumExp(axis).print(); // or tf.logSumExp(a, axis) + * ``` + * @param x The input tensor. + * @param axis The dimension(s) to reduce. If null (the default), + * reduces all dimensions. + * @param keepDims If true, retains reduced dimensions with length + * of 1. Defaults to false. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function logSumExp_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'logSumExp'); + var axes = parseAxisParam(axis, $x.shape); + var xMax = $x.max(axes, true /* keepDims */); + var a = $x.sub(xMax); + var b = a.exp(); + var c = b.sum(axes); + var d = c.log(); + var res = xMax.reshape(d.shape).add(d); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, axes); + return res.reshape(newShape); + } + return res; + } + /** + * Computes the sum of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If axes has no entries, all dimensions are reduced, and a + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.sum().print(); // or tf.sum(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.sum(axis).print(); // or tf.sum(x, axis) + * ``` + * + * @param x The input tensor to compute the sum over. If the dtype is `bool` + * it will be converted to `int32` and the output dtype will be `int32`. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function sum_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'sum'); + if ($x.dtype === 'bool') { + $x = $x.toInt(); + } + var axes = parseAxisParam(axis, $x.shape); + // Use a custom gradient to bypass 2 gradient backprops since sum is used + // extremely often. + var customOp = customGrad(function (x) { + var permutation = getAxesPermutation(axes, x.rank); + var reductionAxes = axes; + var permutedX = x; + if (permutation != null) { + permutedX = x.transpose(permutation); + reductionAxes = getInnerMostAxes(reductionAxes.length, x.rank); + } + var value = ENGINE.runKernelFunc(function (backend) { return backend.sum(permutedX, reductionAxes); }, { permutedX: permutedX }); + if (keepDims) { + var newShape = expandShapeToKeepDim(value.shape, axes); + value = value.reshape(newShape); + } + var gradFunc = function (dy) { + var expandedDyShape = x.shape.slice(); + axes.forEach(function (axis) { + expandedDyShape[axis] = 1; + }); + var expandedDy = dy.reshape(expandedDyShape); + var derX = expandedDy.mul(ones$1(x.shape, 'float32')); + return derX; + }; + return { value: value, gradFunc: gradFunc }; + }); + return customOp($x); + } + /** + * Computes the product of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and a + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.prod().print(); // or tf.prod(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.prod(axis).print(); // or tf.prod(x, axis) + * ``` + * + * @param x The input tensor to compute the product over. If the dtype is `bool` + * it will be converted to `int32` and the output dtype will be `int32`. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function prod_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'prod'); + if ($x.dtype === 'bool') { + $x = $x.toInt(); + } + var axes = parseAxisParam(axis, $x.shape); + var permutation = getAxesPermutation(axes, $x.rank); + var reductionAxes = axes; + var permutedX = $x; + if (permutation != null) { + permutedX = $x.transpose(permutation); + reductionAxes = getInnerMostAxes(reductionAxes.length, $x.rank); + } + var value = ENGINE.runKernelFunc(function (backend) { return backend.prod(permutedX, reductionAxes); }, { permutedX: permutedX }); + if (keepDims) { + var newShape = expandShapeToKeepDim(value.shape, axes); + value = value.reshape(newShape); + } + return value; + } + /** + * Computes the mean of elements across dimensions of a `tf.Tensor`. + * + * Reduces `x` along the dimensions given in `axis`. Unless `keepDims` is + * true, the rank of the `tf.Tensor` is reduced by 1 for each entry in `axis`. + * If `keepDims` is true, the reduced dimensions are retained with length 1. + * If `axis` has no entries, all dimensions are reduced, and a `tf.Tensor` with + * a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.mean().print(); // or tf.mean(a) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.mean(axis).print(); // or tf.mean(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function mean_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'mean'); + var axes = parseAxisParam(axis, $x.shape); + var shapes = computeOutAndReduceShapes($x.shape, axes); + var reduceShape = shapes[1]; + var reduceSize = sizeFromShape(reduceShape); + // Use a custom gradient to bypass 2 gradient backprops since mean is used + // extremely often. + var customOp = customGrad(function (x) { + var reduceSizeScalar = scalar(reduceSize); + // Cast if needed. + var xReduce = reduceSizeScalar.dtype === x.dtype ? x : x.cast(reduceSizeScalar.dtype); + var res = xReduce.div(reduceSizeScalar); + var value = res.sum(axis, keepDims); + var gradFunc = function (dy) { + var expandedDyShape = x.shape.slice(); + axes.forEach(function (axis) { + expandedDyShape[axis] = 1; + }); + var expandedDy = dy.reshape(expandedDyShape); + var derX = expandedDy.mul(ones$1(x.shape, 'float32')).div(reduceSize); + return derX; + }; + return { value: value, gradFunc: gradFunc }; + }); + return customOp($x); + } + /** + * Gradient helper function for the min and max operations. + */ + function gradForMinAndMax(dy, y, xOrig, origAxes, permutedAxes) { + if (y.rank < xOrig.rank) { + y = y.reshape(expandShapeToKeepDim(y.shape, origAxes)); + } + if (dy.rank < xOrig.rank) { + dy = dy.reshape(expandShapeToKeepDim(dy.shape, origAxes)); + } + return { + x: function () { + var dx = dy.mul(xOrig.equal(y).cast(dy.dtype)); + return permutedAxes == null ? dx : dx.transpose(permutedAxes); + } + }; + } + /** + * Computes the minimum value from the input. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the array is reduced by 1 for each entry in `axes`. + * If `keepDims` is true, the reduced dimensions are retained with length 1. + * If `axes` has no entries, all dimensions are reduced, and an array with a + * single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.min().print(); // or tf.min(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.min(axis).print(); // or tf.min(x, axis) + * ``` + * + * @param x The input Tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function min_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'min'); + var xOrig = $x; + var origAxes = parseAxisParam(axis, $x.shape); + var axes = origAxes; + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var grad = function (dy, saved) { + return gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); + }; + var inputsToSave = [$x]; + var outputsToSave = [true]; + var res = ENGINE.runKernelFunc(function (backend, save) { + var y = backend.min($x, axes); + save([xOrig, y]); + return y; + }, { x: $x }, grad, 'Min', { axes: axes }, inputsToSave, outputsToSave); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, origAxes); + res = res.reshape(newShape); + } + return res; + } + /** + * Computes the maximum of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and an + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.max().print(); // or tf.max(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * + * const axis = 1; + * x.max(axis).print(); // or tf.max(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function max_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'max'); + var xOrig = $x; + var origAxes = parseAxisParam(axis, $x.shape); + var axes = origAxes; + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var grad = function (dy, saved) { + return gradForMinAndMax(dy, saved[1], saved[0], origAxes, permutedAxes); + }; + var inputsToSave = [$x]; + var outputsToSave = [true]; + var res = ENGINE.runKernelFunc(function (backend, save) { + var y = backend.max($x, axes); + save([xOrig, y]); + return y; + }, { x: $x }, grad, 'Max', { axes: axes }, inputsToSave, outputsToSave); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, origAxes); + res = res.reshape(newShape); + } + return res; + } + /** + * Returns the indices of the minimum values along an `axis`. + * + * The result has the same shape as `input` with the dimension along `axis` + * removed. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.argMin().print(); // or tf.argMin(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]); + * + * const axis = 1; + * x.argMin(axis).print(); // or tf.argMin(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension). + * + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function argMin_(x, axis) { + if (axis === void 0) { axis = 0; } + var $x = convertToTensor(x, 'x', 'argMin'); + if (axis == null) { + axis = 0; + } + var axes = parseAxisParam(axis, $x.shape); + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return zerosLike($x); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.argMin($x, axes[0]); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Returns the indices of the maximum values along an `axis`. + * + * The result has the same shape as `input` with the dimension along `axis` + * removed. + * + * ```js + * const x = tf.tensor1d([1, 2, 3]); + * + * x.argMax().print(); // or tf.argMax(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 2, 4, 3], [2, 2]); + * + * const axis = 1; + * x.argMax(axis).print(); // or tf.argMax(x, axis) + * ``` + * + * @param x The input tensor. + * @param axis The dimension to reduce. Defaults to 0 (outer-most dimension). + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function argMax_(x, axis) { + if (axis === void 0) { axis = 0; } + var $x = convertToTensor(x, 'x', 'argMax'); + if (axis == null) { + axis = 0; + } + var axes = parseAxisParam(axis, $x.shape); + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return zerosLike($x); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.argMax($x, axes[0]); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes the logical and of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and an + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 1, 1], 'bool'); + * + * x.all().print(); // or tf.all(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool'); + * + * const axis = 1; + * x.all(axis).print(); // or tf.all(x, axis) + * ``` + * + * @param x The input tensor. Must be of dtype bool. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function all_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'all', 'bool'); + var origAxes = parseAxisParam(axis, $x.shape); + var axes = origAxes; + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var res = ENGINE.runKernelFunc(function (backend) { return backend.all($x, axes); }, { $x: $x }); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, origAxes); + return res.reshape(newShape); + } + return res; + } + /** + * Computes the logical or of elements across dimensions of a `tf.Tensor`. + * + * Reduces the input along the dimensions given in `axes`. Unless `keepDims` + * is true, the rank of the `tf.Tensor` is reduced by 1 for each entry in + * `axes`. If `keepDims` is true, the reduced dimensions are retained with + * length 1. If `axes` has no entries, all dimensions are reduced, and an + * `tf.Tensor` with a single element is returned. + * + * ```js + * const x = tf.tensor1d([1, 1, 1], 'bool'); + * + * x.any().print(); // or tf.any(x) + * ``` + * + * ```js + * const x = tf.tensor2d([1, 1, 0, 0], [2, 2], 'bool'); + * + * const axis = 1; + * x.any(axis).print(); // or tf.any(x, axis) + * ``` + * + * @param x The input tensor. Must be of dtype bool. + * @param axis The dimension(s) to reduce. By default it reduces + * all dimensions. + * @param keepDims If true, retains reduced dimensions with size 1. + */ + /** @doc {heading: 'Operations', subheading: 'Reduction'} */ + function any_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + var $x = convertToTensor(x, 'x', 'any', 'bool'); + var origAxes = parseAxisParam(axis, $x.shape); + var axes = origAxes; + var permutedAxes = getAxesPermutation(axes, $x.rank); + if (permutedAxes != null) { + $x = $x.transpose(permutedAxes); + axes = getInnerMostAxes(axes.length, $x.rank); + } + var res = ENGINE.runKernelFunc(function (backend) { return backend.any($x, axes); }, { $x: $x }); + if (keepDims) { + var newShape = expandShapeToKeepDim(res.shape, origAxes); + return res.reshape(newShape); + } + return res; + } + /** + * Calculates the mean and variance of `x`. The mean and variance are + * calculated by aggregating the contents of `x` across `axes`. If `x` is + * 1-D and `axes = [0]` this is just the mean and variance of a vector. + * + * @param x The input tensor. + * @param axis The dimension(s) along with to compute mean and + * variance. By default it reduces all dimensions. + * @param keepDims If true, the moments have the same dimensionality as the + * input. + * @return An object with two keys: `mean` and `variance`. + */ + /** @doc {heading: 'Operations', subheading: 'Normalization'} */ + function moments_(x, axis, keepDims) { + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + x = convertToTensor(x, 'x', 'moments'); + var axes = parseAxisParam(axis, x.shape); + var mean = x.mean(axes, keepDims); + var keepDimsShape = mean.shape; + if (!keepDims) { + keepDimsShape = expandShapeToKeepDim(mean.shape, axes); + } + var devSquared = x.toFloat().sub(mean.reshape(keepDimsShape)).square(); + var variance = devSquared.mean(axes, keepDims); + return { mean: mean, variance: variance }; + } + var all = op({ all_: all_ }); + // tslint:disable-next-line:variable-name + var any = op({ any_: any_ }); + var argMax = op({ argMax_: argMax_ }); + var argMin = op({ argMin_: argMin_ }); + var logSumExp = op({ logSumExp_: logSumExp_ }); + var max = op({ max_: max_ }); + var mean = op({ mean_: mean_ }); + var min = op({ min_: min_ }); + var moments = op({ moments_: moments_ }); + var sum$1 = op({ sum_: sum_ }); + var prod = op({ prod_: prod_ }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes rectified linear element-wise: `max(x, 0)`. + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.relu().print(); // or tf.relu(x) + * ``` + * @param x The input tensor. If the dtype is `bool`, the output dtype will be + * `int32'. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function relu_(x) { + var $x = convertToTensor(x, 'x', 'relu'); + if ($x.dtype === 'bool') { + return $x.toInt(); + } + var grad = function (dy, saved) { + var $x = saved[0]; + return { $x: function () { return dy.mulStrict($x.step().toFloat()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.relu($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes rectified linear 6 element-wise: `min(max(x, 0), 6)`. + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 8]); + * + * x.relu6().print(); // or tf.relu6(x) + * ``` + * @param x The input tensor. If the dtype is `bool`, the output dtype will be + * `int32'. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function relu6_(x) { + var $x = convertToTensor(x, 'x', 'relu6'); + if ($x.dtype === 'bool') { + return $x.toInt(); + } + var grad = function (dy, saved) { + var $x = saved[0]; + var mask = $x.lessEqual(6).mul($x.step()); + return { $x: function () { return dy.mulStrict(mask.toFloat()); } }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.relu6($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes exponential linear element-wise: `x > 0 ? e ^ x - 1 : 0`. + * + * ```js + * const x = tf.tensor1d([-1, 1, -3, 2]); + * + * x.elu().print(); // or tf.elu(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function elu_(x) { + var $x = convertToTensor(x, 'x', 'elu'); + var grad = function (dy, saved) { + var y = saved[0]; + return { + $x: function () { + return ENGINE.runKernelFunc(function (backend) { return backend.eluDer(dy, y); }, { dy: dy, y: y }); + } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var y = backend.elu($x); + save([y]); + return y; + }, { $x: $x }, grad); + } + /** + * Computes scaled exponential linear element-wise. + * + * `x < 0 ? scale * alpha * (exp(x) - 1) : x` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.selu().print(); // or tf.selu(x) + * ``` + * @param x The input tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function selu_(x) { + var $x = convertToTensor(x, 'x', 'selu'); + var grad = function (dy, saved) { + var $x = saved[0]; + return { + $x: function () { + var mask = $x.greater(scalar(0)); + var scaleAlpha = scalar(SELU_SCALEALPHA); + var scale = scalar(SELU_SCALE); + var greaterThanZeroDer = dy.mul(scale); + var lessEqualZeroDer = dy.mul(scaleAlpha).mul($x.toFloat().exp()); + return where(mask, greaterThanZeroDer, lessEqualZeroDer); + } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.selu($x); + save([$x]); + return res; + }, { $x: $x }, grad); + } + /** + * Computes leaky rectified linear element-wise. + * + * See + * [http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf]( + * http://web.stanford.edu/~awni/papers/relu_hybrid_icml2013_final.pdf) + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * + * x.leakyRelu(0.1).print(); // or tf.leakyRelu(x, 0.1) + * ``` + * @param x The input tensor. + * @param alpha The scaling factor for negative values, defaults to 0.2. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function leakyRelu_(x, alpha) { + if (alpha === void 0) { alpha = 0.2; } + var $x = convertToTensor(x, 'x', 'leakyRelu'); + return maximum(scalar(alpha).mul($x), $x); + } + /** + * Computes leaky rectified linear element-wise with parametric alphas. + * + * `x < 0 ? alpha * x : f(x) = x` + * + * ```js + * const x = tf.tensor1d([-1, 2, -3, 4]); + * const alpha = tf.scalar(0.1); + * + * x.prelu(alpha).print(); // or tf.prelu(x, alpha) + * ``` + * @param x The input tensor. + * @param alpha Scaling factor for negative values. + */ + /** @doc {heading: 'Operations', subheading: 'Basic math'} */ + function prelu_(x, alpha) { + var $x = convertToTensor(x, 'x', 'prelu'); + var $alpha = convertToTensor(alpha, 'alpha', 'prelu'); + var grad = function (dy, saved) { + var $x = saved[0], $alpha = saved[1]; + var mask = $x.greater(0); + return { + x: function () { return where(mask, dy, dy.mul($alpha)); }, + alpha: function () { + var res = where(mask, zerosLike(dy), dy.mul($x)); + var reduceAxes = getReductionAxes($alpha.shape, dy.shape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape($alpha.shape); + } + }; + }; + return ENGINE.runKernelFunc(function (backend, save) { + var res = backend.prelu($x, $alpha); + save([$x, $alpha]); + return res; + }, { x: $x, alpha: $alpha }, grad, 'Prelu'); + } + var elu = op({ elu_: elu_ }); + var leakyRelu = op({ leakyRelu_: leakyRelu_ }); + var prelu = op({ prelu_: prelu_ }); + var relu = op({ relu_: relu_ }); + var relu6 = op({ relu6_: relu6_ }); + var selu = op({ selu_: selu_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Transposes the `tf.Tensor`. Permutes the dimensions according to `perm`. + * + * The returned `tf.Tensor`'s dimension `i` will correspond to the input + * dimension `perm[i]`. If `perm` is not given, it is set to `[n-1...0]`, + * where `n` is the rank of the input `tf.Tensor`. Hence by default, this + * operation performs a regular matrix transpose on 2-D input `tf.Tensor`s. + * + * ```js + * const a = tf.tensor2d([1, 2, 3, 4, 5, 6], [2, 3]); + * + * a.transpose().print(); // or tf.transpose(a) + * ``` + * + * @param x The tensor to transpose. + * @param perm The permutation of the dimensions of a. + */ + /** @doc {heading: 'Operations', subheading: 'Matrices'} */ + function transpose_(x, perm) { + var $x = convertToTensor(x, 'x', 'transpose'); + if (perm == null) { + perm = $x.shape.map(function (s, i) { return i; }).reverse(); + } + assert($x.rank === perm.length, function () { return "Error in transpose: rank of input " + $x.rank + " " + + ("must match length of perm " + perm + "."); }); + perm.forEach(function (axis) { + assert(axis >= 0 && axis < $x.rank, function () { return "All entries in 'perm' must be between 0 and " + ($x.rank - 1) + + (" but got " + perm); }); + }); + if ($x.rank <= 1) { + return $x.clone(); + } + var der = function (dy) { + var undoPerm = getUndoAxesPermutation(perm); + return { x: function () { return dy.transpose(undoPerm); } }; + }; + var attrs = { perm: perm }; + return ENGINE.runKernelFunc(function (backend) { return backend.transpose($x, perm); }, { x: $x }, der, 'Transpose', attrs); + } + var transpose = op({ transpose_: transpose_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Normalizes the activation of a local neighborhood across or within + * channels. + * + * @param x The input tensor. The 4-D input tensor is treated as a 3-D array + * of 1D vectors (along the last dimension), and each vector is + * normalized independently. + * @param depthRadius The number of adjacent channels in the 1D normalization + * window. + * @param bias A constant bias term for the basis. + * @param alpha A scale factor, usually positive. + * @param beta An exponent. + */ + /** @doc {heading: 'Operations', subheading: 'Normalization'} */ + function localResponseNormalization_(x, depthRadius, bias, alpha, beta) { + if (depthRadius === void 0) { depthRadius = 5; } + if (bias === void 0) { bias = 1; } + if (alpha === void 0) { alpha = 1; } + if (beta === void 0) { beta = 0.5; } + var $x = convertToTensor(x, 'x', 'localResponseNormalization'); + assert($x.rank === 4 || $x.rank === 3, function () { return "Error in localResponseNormalization: x must be rank 3 or 4 but got\n rank " + $x.rank + "."; }); + assert(isInt(depthRadius), function () { return "Error in localResponseNormalization: depthRadius must be an " + + ("integer but got depthRadius " + depthRadius + "."); }); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + var backward = function (dy, saved) { + var x4D = saved[0], y = saved[1]; + return { + x4D: function () { return ENGINE.runKernelFunc(function (backend) { return backend.LRNGrad(dy, x4D, y, depthRadius, bias, alpha, beta); }, {}); } + }; + }; + var res = ENGINE.runKernelFunc(function (backend, save) { + var y = backend.localResponseNormalization4D(x4D, depthRadius, bias, alpha, beta); + save([x4D, y]); + return y; + }, { x4D: x4D }, backward); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + else { + return res; + } + } + var localResponseNormalization = op({ localResponseNormalization_: localResponseNormalization_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the norm of scalar, vectors, and matrices. + * This function can compute several different vector norms (the 1-norm, the + * Euclidean or 2-norm, the inf-norm, and in general the p-norm for p > 0) + * and matrix norms (Frobenius, 1-norm, and inf-norm). + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * + * x.norm().print(); // or tf.norm(x) + * ``` + * + * @param x The input array. + * @param ord Optional. Order of the norm. Supported norm types are + * following: + * + * | ord | norm for matrices | norm for vectors + * |------------|---------------------------|--------------------- + * |'euclidean' |Frobenius norm |2-norm + * |'fro' |Frobenius norm | + * |Infinity |max(sum(abs(x), axis=1)) |max(abs(x)) + * |-Infinity |min(sum(abs(x), axis=1)) |min(abs(x)) + * |1 |max(sum(abs(x), axis=0)) |sum(abs(x)) + * |2 | |sum(abs(x)^2)^1/2* + * + * @param axis Optional. If axis is null (the default), the input is + * considered a vector and a single vector norm is computed over the entire + * set of values in the Tensor, i.e. norm(x, ord) is equivalent + * to norm(x.reshape([-1]), ord). If axis is a integer, the input + * is considered a batch of vectors, and axis determines the axis in x + * over which to compute vector norms. If axis is a 2-tuple of integer it is + * considered a batch of matrices and axis determines the axes in NDArray + * over which to compute a matrix norm. + * @param keepDims Optional. If true, the norm have the same dimensionality + * as the input. + */ + /** @doc {heading: 'Operations', subheading: 'Matrices'} */ + function norm_(x, ord, axis, keepDims) { + if (ord === void 0) { ord = 'euclidean'; } + if (axis === void 0) { axis = null; } + if (keepDims === void 0) { keepDims = false; } + x = convertToTensor(x, 'x', 'norm'); + var norm = normImpl(x, ord, axis); + var keepDimsShape = norm.shape; + if (keepDims) { + var axes = parseAxisParam(axis, x.shape); + keepDimsShape = expandShapeToKeepDim(norm.shape, axes); + } + return norm.reshape(keepDimsShape); + } + function normImpl(x, p, axis) { + if (axis === void 0) { axis = null; } + if (x.rank === 0) { + return x.abs(); + } + // consider vector when no axis is specified + if (x.rank !== 1 && axis === null) { + return normImpl(x.reshape([-1]), p, axis); + } + // vector + if (x.rank === 1 || typeof axis === 'number' || + Array.isArray(axis) && axis.length === 1) { + if (p === 1) { + return x.abs().sum(axis); + } + if (p === Infinity) { + return x.abs().max(axis); + } + if (p === -Infinity) { + return x.abs().min(axis); + } + if (p === 'euclidean' || p === 2) { + // norm(x, 2) = sum(abs(xi) ^ 2) ^ 1/2 + return x.abs().pow(scalar(2, 'int32')).sum(axis).sqrt(); + } + throw new Error("Error in norm: invalid ord value: " + p); + } + // matrix (assumption axis[0] < axis[1]) + if (Array.isArray(axis) && axis.length === 2) { + if (p === 1) { + return x.abs().sum(axis[0]).max(axis[1] - 1); + } + if (p === Infinity) { + return x.abs().sum(axis[1]).max(axis[0]); + } + if (p === -Infinity) { + return x.abs().sum(axis[1]).min(axis[0]); + } + if (p === 'fro' || p === 'euclidean') { + // norm(x) = sqrt(sum(pow(x, 2))) + return x.square().sum(axis).sqrt(); + } + throw new Error("Error in norm: invalid ord value: " + p); + } + throw new Error("Error in norm: invalid axis: " + axis); + } + var norm = op({ norm_: norm_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the next states and outputs of a stack of LSTMCells. + * + * Each cell output is used as input to the next cell. + * + * Returns `[cellState, cellOutput]`. + * + * Derived from tf.contrib.rn.MultiRNNCell. + * + * @param lstmCells Array of LSTMCell functions. + * @param data The input to the cell. + * @param c Array of previous cell states. + * @param h Array of previous cell outputs. + */ + /** @doc {heading: 'Operations', subheading: 'RNN'} */ + function multiRNNCell_(lstmCells, data, c, h) { + var $data = convertToTensor(data, 'data', 'multiRNNCell'); + var $c = convertToTensorArray(c, 'c', 'multiRNNCell'); + var $h = convertToTensorArray(h, 'h', 'multiRNNCell'); + var input = $data; + var newStates = []; + for (var i = 0; i < lstmCells.length; i++) { + var output = lstmCells[i](input, $c[i], $h[i]); + newStates.push(output[0]); + newStates.push(output[1]); + input = output[1]; + } + var newC = []; + var newH = []; + for (var i = 0; i < newStates.length; i += 2) { + newC.push(newStates[i]); + newH.push(newStates[i + 1]); + } + return [newC, newH]; + } + /** + * Computes the next state and output of a BasicLSTMCell. + * + * Returns `[newC, newH]`. + * + * Derived from tf.contrib.rnn.BasicLSTMCell. + * + * @param forgetBias Forget bias for the cell. + * @param lstmKernel The weights for the cell. + * @param lstmBias The bias for the cell. + * @param data The input to the cell. + * @param c Previous cell state. + * @param h Previous cell output. + */ + /** @doc {heading: 'Operations', subheading: 'RNN'} */ + function basicLSTMCell_(forgetBias, lstmKernel, lstmBias, data, c, h) { + var $forgetBias = convertToTensor(forgetBias, 'forgetBias', 'basicLSTMCell'); + var $lstmKernel = convertToTensor(lstmKernel, 'lstmKernel', 'basicLSTMCell'); + var $lstmBias = convertToTensor(lstmBias, 'lstmBias', 'basicLSTMCell'); + var $data = convertToTensor(data, 'data', 'basicLSTMCell'); + var $c = convertToTensor(c, 'c', 'basicLSTMCell'); + var $h = convertToTensor(h, 'h', 'basicLSTMCell'); + var combined = $data.concat($h, 1); + var weighted = combined.matMul($lstmKernel); + var res = weighted.add($lstmBias); + // i = input_gate, j = new_input, f = forget_gate, o = output_gate + var batchSize = res.shape[0]; + var sliceCols = res.shape[1] / 4; + var sliceSize = [batchSize, sliceCols]; + var i = res.slice([0, 0], sliceSize); + var j = res.slice([0, sliceCols], sliceSize); + var f = res.slice([0, sliceCols * 2], sliceSize); + var o = res.slice([0, sliceCols * 3], sliceSize); + var newC = i.sigmoid().mulStrict(j.tanh()).addStrict($c.mulStrict($forgetBias.add(f).sigmoid())); + var newH = newC.tanh().mulStrict(o.sigmoid()); + return [newC, newH]; + } + var basicLSTMCell = op({ basicLSTMCell_: basicLSTMCell_ }); + var multiRNNCell = op({ multiRNNCell_: multiRNNCell_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Compute the moving average of a variable. + * + * Without zeroDebias, the moving average operation is defined by: + * `v += delta` + * where + * `delta = (1 - decay) * (x - v)` + * + * With zeroDebias (default), the `delta` term is scaled to debias the + * effect of the (assumed) zero-initialization of `v`. + * `delta /= (1 - decay ^ step)` + * + * For more details on the zero-debiasing algorithm, see: + * https://arxiv.org/abs/1412.6980 + * + * Note that this function is completely stateless and does not keep track of + * step count. The step count needs to be maintained by the caller and passed + * in as `step`. + * + * @param v The current moving average value. + * @param x New input value, must have the same shape and dtype as `v`. + * @param decay The decay factor. Typical values are 0.95 and 0.99. + * @param step Step count. + * @param zeroDebias: Whether zeroDebias is to be performed (default: `true`). + * @returns The new moving average value. + */ + /** @doc {heading: 'Operations', subheading: 'Moving Average'} */ + function movingAverage_(v, x, decay, step, zeroDebias) { + if (zeroDebias === void 0) { zeroDebias = true; } + var $v = convertToTensor(v, 'v', 'movingAverage'); + var $x = convertToTensor(x, 'x', 'movingAverage'); + var $decay = convertToTensor(decay, 'decay', 'movingAverage'); + assertTypesMatch($v, $x); + assert(arraysEqual($v.shape, $x.shape), function () { return 'Shape mismatch in v and x'; }); + var one = scalar(1); + var oneMinusDecay = one.sub($decay); + var update = $x.sub($v).mul(oneMinusDecay); + if (zeroDebias) { + assert(step != null, function () { return 'When using zeroDebias: true, step is required.'; }); + var $step = convertToTensor(step, 'step', 'movingAverage'); + update = update.div(one.sub(pow($decay, $step))); + } + return $v.add(update); + } + var movingAverage = op({ movingAverage_: movingAverage_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Extracts a strided slice of a tensor. + * + * Roughly speaking, this op extracts a slice of size (end-begin)/stride from + * the given input tensor (x). Starting at the location specified by begin the + * slice continues by adding stride to the index until all dimensions are not + * less than end. Note that a stride can be negative, which causes a reverse + * slice. + * + * ```js + * const t = tf.tensor3d([1, 1, 1 ,2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6], + * [3, 2, 3]); + * t.stridedSlice([1, 0, 0], [2, 1, 3], [1, 1, 1]).print() // [[[3, 3, 3]]] + * t.stridedSlice([1, 0, 0], [2, 2, 3], [1, 1, 1]).print() // [[[3, 3, 3], + * // [4, 4, 4]]] + * t.stridedSlice([1, -1, 0], [2, -3, 3], [1, -1, 1]).print() // [[[4, 4, 4], + * // [3, 3, 3]]] + * ``` + * + * @param x The tensor to stride slice. + * @param begin The coordinates to start the slice from. + * @param end: The coordinates to end the slice at. + * @param strides: The size of the slice. + * @param beginMask: If the ith bit of beginMask is set, begin[i] is ignored + * and the fullest possible range in that dimension is used instead. + * @param endMask: If the ith bit of endMask is set, end[i] is ignored + * and the fullest possible range in that dimension is used instead. + * @param shrinkAxisMask: a bitmask where bit i implies that + * the ith specification should shrink the dimensionality. begin and end must + * imply a slice of size 1 in the dimension. + */ + /** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */ + function stridedSlice_(x, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask) { + if (beginMask === void 0) { beginMask = 0; } + if (endMask === void 0) { endMask = 0; } + if (ellipsisMask === void 0) { ellipsisMask = 0; } + if (newAxisMask === void 0) { newAxisMask = 0; } + if (shrinkAxisMask === void 0) { shrinkAxisMask = 0; } + if (strides == null) { + strides = new Array(begin.length); + } + if (ellipsisMask !== 0) { + throw new Error('ellipsis mask is not yet supported'); + } + var $x = convertToTensor(x, 'x', 'stridedSlice'); + // Expand the dims of x based on the newAxisMask. + var expandAxes = maskToAxes(newAxisMask); + var newShape = $x.shape.slice(); + expandAxes.forEach(function (axis) { + begin[axis] = 0; + end[axis] = 1; + newShape.splice(axis, 0, 1); + }); + $x = $x.reshape(newShape); + // Normalize the start, end and strides. + for (var axis = 0; axis < $x.rank; axis++) { + begin[axis] = startForAxis(beginMask, begin, strides, $x.shape, axis); + end[axis] = stopForAxis(endMask, end, strides, $x.shape, axis); + strides[axis] = strides[axis] || 1; + } + var shrinkAxes = maskToAxes(shrinkAxisMask); + // Adjust the ends based on the shrink mask. + shrinkAxes.forEach(function (axis) { + end[axis] = begin[axis] + 1; + strides[axis] = 1; + }); + // Figure out the output shape. + var size = computeOutShape$2(begin, end, strides); + // Remove the axes based on shrinkMask. + var outShape = size.filter(function (_, axis) { return shrinkAxes.indexOf(axis) === -1; }); + var nonStrided = strides.every(function (v) { return v === 1; }); + if (nonStrided) { + return slice($x, begin, size).reshape(outShape); + } + var res = ENGINE.runKernelFunc(function (backend) { return backend.stridedSlice($x, begin, end, strides); }, { $x: $x }); + return res.reshape(outShape); + } + var stridedSlice = op({ stridedSlice_: stridedSlice_ }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Finds the values and indices of the `k` largest entries along the last + * dimension. + * + * If the input is a vector (rank=1), finds the k largest entries in the vector + * and outputs their values and indices as vectors. Thus values[j] is the j-th + * largest entry in input, and its index is indices[j]. + * For higher rank inputs, computes the top k entries along the last dimension. + * + * If two elements are equal, the lower-index element appears first. + * + * ```js + * const a = tf.tensor2d([[1, 5], [4, 3]]); + * const {values, indices} = tf.topk(a); + * values.print(); + * indices.print(); + * ``` + * @param x 1-D or higher `tf.Tensor` with last dimension being at least `k`. + * @param k Number of top elements to look for along the last dimension. + * @param sorted If true, the resulting `k` elements will be sorted by the + * values in descending order. + */ + /** @doc {heading: 'Operations', subheading: 'Evaluation'} */ + function topk_(x, k, sorted) { + if (k === void 0) { k = 1; } + if (sorted === void 0) { sorted = true; } + var $x = convertToTensor(x, 'x', 'topk'); + if ($x.rank === 0) { + throw new Error('topk() expects the input to be of rank 1 or higher'); + } + var lastDim = $x.shape[$x.shape.length - 1]; + if (k > lastDim) { + throw new Error("'k' passed to topk() must be <= the last dimension (" + lastDim + ") " + + ("but got " + k)); + } + var _a = ENGINE.runKernelFunc(function (b) { return b.topk($x, k, sorted); }, { $x: $x }), values = _a[0], indices = _a[1]; + return { values: values, indices: indices }; + } + var topk = op({ topk_: topk_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Creates a new tensor by applying sparse updates to individual + * values or slices within a zero tensor of the given shape tensor according to + * indices. This operator is the inverse of the `tf.gatherND` operator which + * extracts values or slices from a given tensor. + * + * ```js + * const indices = tf.tensor2d([4, 3, 1, 7], [4, 1], 'int32'); + * const updates = tf.tensor1d([9, 10, 11, 12]); + * const shape = [8]; + * tf.scatterND(indices, updates, shape).print() //[0, 11, 0, 10, 9, 0, 0, 12] + * ``` + * + * @param indices The tensor contains the indices into the output tensor. + * @param updates The tensor contains the value for the indices. + * @param shape: The shape of the output tensor. + */ + /** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */ + function scatterND_(indices, updates, shape) { + var $indices = convertToTensor(indices, 'indices', 'scatterND', 'int32'); + var $updates = convertToTensor(updates, 'updates', 'scatterND'); + validateInput($updates, $indices, shape); + return ENGINE.runKernelFunc(function (backend) { return backend.scatterND($indices, $updates, shape); }, { $indices: $indices, $updates: $updates }); + } + var scatterND = op({ scatterND_: scatterND_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Fast Fourier transform. + * + * Computes the 1-dimensional discrete Fourier transform over the inner-most + * dimension of input. + * + * ```js + * const real = tf.tensor1d([1, 2, 3]); + * const imag = tf.tensor1d([1, 2, 3]); + * const x = tf.complex(real, imag); + * + * x.fft().print(); // tf.spectral.fft(x).print(); + * ``` + * @param input The complex input to compute an fft over. + */ + /** + * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'} + */ + function fft_(input) { + assert(input.dtype === 'complex64', function () { return "The dtype for tf.spectral.fft() must be complex64 " + + ("but got " + input.dtype + "."); }); + // Collapse all outer dimensions to a single batch dimension. + var innerDimensionSize = input.shape[input.shape.length - 1]; + var batch = input.size / innerDimensionSize; + var input2D = input.as2D(batch, innerDimensionSize); + var ret = ENGINE.runKernelFunc(function (backend) { return backend.fft(input2D); }, { input: input }); + return ret.reshape(input.shape); + } + /** + * Inverse fast Fourier transform. + * + * Computes the inverse 1-dimensional discrete Fourier transform over the + * inner-most dimension of input. + * + * ```js + * const real = tf.tensor1d([1, 2, 3]); + * const imag = tf.tensor1d([1, 2, 3]); + * const x = tf.complex(real, imag); + * + * x.ifft().print(); // tf.spectral.ifft(x).print(); + * ``` + * @param input The complex input to compute an ifft over. + */ + /** + * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'} + */ + function ifft_(input) { + assert(input.dtype === 'complex64', function () { return "The dtype for tf.spectral.ifft() must be complex64 " + + ("but got " + input.dtype + "."); }); + // Collapse all outer dimensions to a single batch dimension. + var innerDimensionSize = input.shape[input.shape.length - 1]; + var batch = input.size / innerDimensionSize; + var input2D = input.as2D(batch, innerDimensionSize); + var ret = ENGINE.runKernelFunc(function (backend) { return backend.ifft(input2D); }, { input: input }); + return ret.reshape(input.shape); + } + /** + * Real value input fast Fourier transform. + * + * Computes the 1-dimensional discrete Fourier transform over the + * inner-most dimension of the real input. + * + * ```js + * const real = tf.tensor1d([1, 2, 3]); + * + * real.rfft().print(); + * ``` + * @param input The real value input to compute an rfft over. + */ + /** + * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'} + */ + function rfft_(input, fftLength) { + assert(input.dtype === 'float32', function () { return "The dtype for rfft() must be real value but got " + input.dtype; }); + var innerDimensionSize = input.shape[input.shape.length - 1]; + var batch = input.size / innerDimensionSize; + var adjustedInput; + if (fftLength != null && fftLength < innerDimensionSize) { + // Need to crop + var begin = input.shape.map(function (v) { return 0; }); + var size = input.shape.map(function (v) { return v; }); + size[input.shape.length - 1] = fftLength; + adjustedInput = input.slice(begin, size); + innerDimensionSize = fftLength; + } + else if (fftLength != null && fftLength > innerDimensionSize) { + // Need to pad with zeros + var zerosShape = input.shape.map(function (v) { return v; }); + zerosShape[input.shape.length - 1] = fftLength - innerDimensionSize; + adjustedInput = input.concat(zeros(zerosShape), input.shape.length - 1); + innerDimensionSize = fftLength; + } + else { + adjustedInput = input; + } + // Complement the input with zero imaginary numbers. + var zerosInput = adjustedInput.zerosLike(); + var complexInput = complex(adjustedInput, zerosInput).as2D(batch, innerDimensionSize); + var ret = fft(complexInput); + // Exclude complex conjugations. These conjugations are put symmetrically. + var half = Math.floor(innerDimensionSize / 2) + 1; + var realValues = real(ret); + var imagValues = imag(ret); + var realComplexConjugate = realValues.split([half, innerDimensionSize - half], realValues.shape.length - 1); + var imagComplexConjugate = imagValues.split([half, innerDimensionSize - half], imagValues.shape.length - 1); + var outputShape = adjustedInput.shape.slice(); + outputShape[adjustedInput.shape.length - 1] = half; + return complex(realComplexConjugate[0], imagComplexConjugate[0]) + .reshape(outputShape); + } + /** + * Inversed real value input fast Fourier transform. + * + * Computes the 1-dimensional inversed discrete Fourier transform over the + * inner-most dimension of the real input. + * + * ```js + * const real = tf.tensor1d([1, 2, 3]); + * const imag = tf.tensor1d([0, 0, 0]); + * const x = tf.complex(real, imag); + * + * x.irfft().print(); + * ``` + * @param input The real value input to compute an irfft over. + */ + /** + * @doc {heading: 'Operations', subheading: 'Spectral', namespace: 'spectral'} + */ + function irfft_(input) { + var innerDimensionSize = input.shape[input.shape.length - 1]; + var batch = input.size / innerDimensionSize; + if (innerDimensionSize <= 2) { + var complexInput = input.as2D(batch, innerDimensionSize); + var ret = ifft(complexInput); + return real(ret); + } + else { + // The length of unique components of the DFT of a real-valued signal + // is 2 * (input_len - 1) + var outputShape = [batch, 2 * (innerDimensionSize - 1)]; + var realInput = real(input).as2D(batch, innerDimensionSize); + var imagInput = imag(input).as2D(batch, innerDimensionSize); + var realConjugate = realInput.slice([0, 1], [batch, innerDimensionSize - 2]).reverse(1); + var imagConjugate = imagInput.slice([0, 1], [batch, innerDimensionSize - 2]) + .reverse(1) + .mul(scalar(-1)); + var r = realInput.concat(realConjugate, 1); + var i = imagInput.concat(imagConjugate, 1); + var complexInput = complex(r, i).as2D(outputShape[0], outputShape[1]); + var ret = ifft(complexInput); + return real(ret); + } + } + var fft = op({ fft_: fft_ }); + var ifft = op({ ifft_: ifft_ }); + var rfft = op({ rfft_: rfft_ }); + var irfft = op({ irfft_: irfft_ }); + + var spectral_ops = /*#__PURE__*/Object.freeze({ + fft: fft, + ifft: ifft, + rfft: rfft, + irfft: irfft + }); + + /** + * Validate sparseToDense inputs. + * + * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32. + * sparseIndices[i] contains the complete index where sparseValues[i] will be + * placed. + * @param sparseValues A 0-D or 1-D Tensor. Values + * corresponding to each row of sparseIndices, or a scalar value to be used for + * all sparse indices. + * @param outputShape number[]. Shape of the dense output tensor. + * @param validateIndices boolean. indice validation is not supported, error + * will be thrown if it is set. + */ + function validateInput$1(sparseIndices, sparseValues, outputShape, defaultValues) { + if (sparseIndices.dtype !== 'int32') { + throw new Error('tf.sparseToDense() expects the indices to be int32 type,' + + (" but the dtype was " + sparseIndices.dtype + ".")); + } + if (sparseIndices.rank > 2) { + throw new Error('sparseIndices should be a scalar, vector, or matrix,' + + (" but got shape " + sparseIndices.shape + ".")); + } + var numElems = sparseIndices.rank > 0 ? sparseIndices.shape[0] : 1; + var numDims = sparseIndices.rank > 1 ? sparseIndices.shape[1] : 1; + if (outputShape.length !== numDims) { + throw new Error('outputShape has incorrect number of elements:,' + + (" " + outputShape.length + ", should be: " + numDims + ".")); + } + var numValues = sparseValues.size; + if (!(sparseValues.rank === 0 || + sparseValues.rank === 1 && numValues === numElems)) { + throw new Error('sparseValues has incorrect shape ' + + (sparseValues.shape + ", should be [] or [" + numElems + "]")); + } + if (sparseValues.dtype !== defaultValues.dtype) { + throw new Error('sparseValues.dtype must match defaultValues.dtype'); + } + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Converts a sparse representation into a dense tensor. + * + * Builds an array dense with shape outputShape such that: + * + * // If sparseIndices is scalar + * dense[i] = (i == sparseIndices ? sparseValues : defaultValue) + * + * // If sparseIndices is a vector, then for each i + * dense[sparseIndices[i]] = sparseValues[i] + * + * // If sparseIndices is an n by d matrix, then for each i in [0, n) + * dense[sparseIndices[i][0], ..., sparseIndices[i][d-1]] = sparseValues[i] + * All other values in dense are set to defaultValue. If sparseValues is a + * scalar, all sparse indices are set to this single value. + * + * If indices are repeated the final value is summed over all values for those + * indices. + * + * ```js + * const indices = tf.tensor1d([4, 5, 6, 1, 2, 3], 'int32'); + * const values = tf.tensor1d([10, 11, 12, 13, 14, 15], 'float32'); + * const shape = [8]; + * tf.sparseToDense(indices, values, shape).print(); + * ``` + * + * @param sparseIndices A 0-D, 1-D, or 2-D Tensor of type int32. + * sparseIndices[i] contains the complete index where sparseValues[i] will be + * placed. + * @param sparseValues A 0-D or 1-D Tensor. Values + * corresponding to each row of sparseIndices, or a scalar value to be used for + * all sparse indices. + * @param outputShape Shape of the dense output tensor. the type is inferred. + * @param defaultValue Scalar. Value to set for indices not specified in + * sparseIndices. Defaults to zero. + */ + /** @doc {heading: 'Operations', subheading: 'Normalization'} */ + function sparseToDense_(sparseIndices, sparseValues, outputShape, defaultValue) { + if (defaultValue === void 0) { defaultValue = 0; } + var $sparseIndices = convertToTensor(sparseIndices, 'sparseIndices', 'sparseToDense', 'int32'); + var $sparseValues = convertToTensor(sparseValues, 'sparseValues', 'sparseToDense'); + var $defaultValue = convertToTensor(defaultValue, 'defaultValue', 'sparseToDense', $sparseValues.dtype); + validateInput$1($sparseIndices, $sparseValues, outputShape, $defaultValue); + return ENGINE.runKernelFunc(function (backend) { return backend.sparseToDense($sparseIndices, $sparseValues, outputShape, $defaultValue); }, { $sparseIndices: $sparseIndices, $sparseValues: $sparseValues, $defaultValue: $defaultValue }); + } + var sparseToDense = op({ sparseToDense_: sparseToDense_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Gather slices from input tensor into a Tensor with shape specified by + * `indices`. + * + * `indices` is an K-dimensional integer tensor, best thought of as a + * (K-1)-dimensional tensor of indices into input, where each element defines a + * slice of input: + * output[\\(i_0, ..., i_{K-2}\\)] = input[indices[\\(i_0, ..., i_{K-2}\\)]] + * + * Whereas in `tf.gather`, `indices` defines slices into the first dimension of + * input, in `tf.gatherND`, `indices` defines slices into the first N dimensions + * of input, where N = indices.shape[-1]. + * + * The last dimension of indices can be at most the rank of input: + * indices.shape[-1] <= input.rank + * + * The last dimension of `indices` corresponds to elements + * (if indices.shape[-1] == input.rank) or slices + * (if indices.shape[-1] < input.rank) along dimension indices.shape[-1] of + * input. + * The output tensor has shape + * indices.shape[:-1] + input.shape[indices.shape[-1]:] + * + * Note that on CPU, if an out of bound index is found, an error is returned. On + * GPU, if an out of bound index is found, a 0 is stored in the corresponding + * output value. + * + * ```js + * const indices = tf.tensor2d([0, 1, 1, 0], [2,2], 'int32'); + * const input = tf.tensor2d([9, 10, 11, 12], [2, 2]); + * tf.gatherND(input, indices).print() // [10, 11] + * ``` + * + * @param x The tensor from which to gather values. + * @param indices Index tensor, must be of type int32. + */ + /** @doc {heading: 'Operations', subheading: 'Slicing and Joining'} */ + function gatherND_(x, indices) { + var $indices = convertToTensor(indices, 'indices', 'gatherND', 'int32'); + var $x = convertToTensor(x, 'x', 'gatherND'); + return ENGINE.runKernelFunc(function (backend) { return backend.gatherND($x, $indices); }, { $x: $x, $indices: $indices }); + } + var gatherND = op({ gatherND_: gatherND_ }); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Returns a diagonal tensor with a given diagonal values. + * + * Given a diagonal, this operation returns a tensor with the diagonal and + * everything else padded with zeros. + * + * Assume the input has dimensions `[D1,..., Dk]`, then the output is a tensor + * of rank 2k with dimensions `[D1,..., Dk, D1,..., Dk]` + * + * ```js + * const x = tf.tensor1d([1, 2, 3, 4]); + * + * tf.diag(x).print() + * ``` + * ```js + * const x = tf.tensor1d([1, 2, 3, 4, 5, 6, 6, 8], [4, 2]) + * + * tf.diag(x).print() + * ``` + * @param x The input tensor. + */ + function diag_(x) { + var $x = convertToTensor(x, 'x', 'diag').flatten(); + var outShape = x.shape.concat(x.shape); + return ENGINE.runKernelFunc(function (backend) { return backend.diag($x); }, { $x: $x }) + .reshape(outShape); + } + var diag = op({ diag_: diag_ }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Normalize noise shape based on provided tensor and noise shape. + * + * @param x Tensor. + * @param noiseShape The shape for the randomly generated keep/drop flags, as + * an array of numbers. Optional. + * @returns Normalized noise shape. + */ + function getNoiseShape(x, noiseShape) { + if (noiseShape == null) { + return x.shape.slice(); + } + if (arraysEqual(x.shape, noiseShape)) { + return noiseShape; + } + if (x.shape.length === noiseShape.length) { + var newDimension = []; + for (var i = 0; i < x.shape.length; i++) { + if (noiseShape[i] == null && x.shape[i] != null) { + newDimension.push(x.shape[i]); + } + else { + newDimension.push(noiseShape[i]); + } + } + return newDimension; + } + return noiseShape; + } + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes dropout. + * + * ```js + * const x = tf.tensor1d([1, 2, 2, 1]); + * const rate = 0.75; + * const output = tf.dropout(x, rate); + * output.print(); + * ``` + * + * @param x A floating point Tensor or TensorLike. + * @param rate A float in the range [0, 1). The probability that each element + * of x is discarded. + * @param noiseShape An array of numbers of type int32, representing the + * shape for randomly generated keep/drop flags. If the noiseShape has null + * value, it will be automatically replaced with the x's relative dimension + * size. Optional. + * @param seed Used to create random seeds. Optional. + * @returns A Tensor of the same shape of x. + */ + /** @doc {heading: 'Operations', subheading: 'Dropout'} */ + function dropout_(x, rate, noiseShape, seed) { + var $x = convertToTensor(x, 'x', 'dropout'); + assert($x.dtype === 'float32', function () { return "x has to be a floating point tensor since it's going to be " + + ("scaled, but got a " + $x.dtype + " tensor instead."); }); + assert(rate >= 0 && rate < 1, function () { return "rate must be a float in the range [0, 1), but got " + rate + "."; }); + if (rate === 0) { + return x instanceof Tensor ? $x.clone() : $x; + } + var $noiseShape = getNoiseShape($x, noiseShape); + var keepProb = 1 - rate; + var multiplier = randomUniform($noiseShape, 0, 1, 'float32', seed) + .add(keepProb) + .floor() + .div(keepProb); + return $x.mul(multiplier); + } + var dropout = op({ dropout_: dropout_ }); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Generate a Hann window. + * + * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows + * + * ```js + * tf.signal.hannWindow(10).print(); + * ``` + * @param The length of window + */ + /** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ + function hannWindow_(windowLength) { + return cosineWindow(windowLength, 0.5, 0.5); + } + /** + * Generate a hamming window. + * + * See: https://en.wikipedia.org/wiki/Window_function#Hann_and_Hamming_windows + * + * ```js + * tf.signal.hammingWindow(10).print(); + * ``` + * @param The length of window + */ + /** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ + function hammingWindow_(windowLength) { + return cosineWindow(windowLength, 0.54, 0.46); + } + /** + * Expands input into frames of frameLength. + * Slides a window size with frameStep. + * + * ```js + * tf.signal.frame([1, 2, 3], 2, 1).print(); + * ``` + * @param signal The input tensor to be expanded + * @param frameLength Length of each frame + * @param frameStep The frame hop size in samples. + * @param padEnd Whether to pad the end of signal with padValue. + * @param padValue An number to use where the input signal does + * not exist when padEnd is True. + */ + /** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ + function frame_(signal, frameLength, frameStep, padEnd, padValue) { + if (padEnd === void 0) { padEnd = false; } + if (padValue === void 0) { padValue = 0; } + var start = 0; + var output = []; + while (start + frameLength <= signal.size) { + output.push(slice(signal, start, frameLength)); + start += frameStep; + } + if (padEnd) { + while (start < signal.size) { + var padLen = (start + frameLength) - signal.size; + var pad = concat([slice(signal, start, frameLength - padLen), + fill([padLen], padValue)]); + output.push(pad); + start += frameStep; + } + } + if (output.length === 0) { + return tensor2d([], [0, frameLength]); + } + return concat(output).as2D(output.length, frameLength); + } + /** + * Computes the Short-time Fourier Transform of signals + * See: https://en.wikipedia.org/wiki/Short-time_Fourier_transform + * + * ```js + * const input = tf.tensor1d([1, 1, 1, 1, 1]) + * tf.signal.stft(input, 3, 1).print(); + * ``` + * @param signal 1-dimensional real value tensor. + * @param frameLength The window length of samples. + * @param frameStep The number of samples to step. + * @param fftLength The size of the FFT to apply. + * @param windowFn A callable that takes a window length and returns 1-d tensor. + */ + /** + * @doc {heading: 'Operations', subheading: 'Signal', namespace: 'signal'} + */ + function stft_(signal, frameLength, frameStep, fftLength, windowFn) { + if (windowFn === void 0) { windowFn = hannWindow; } + if (fftLength == null) { + fftLength = enclosingPowerOfTwo(frameLength); + } + var framedSignal = frame(signal, frameLength, frameStep); + var windowedSignal = mul(framedSignal, windowFn(frameLength)); + var output = []; + for (var i = 0; i < framedSignal.shape[0]; i++) { + output.push(rfft(windowedSignal.slice([i, 0], [1, frameLength]), fftLength)); + } + return concat(output); + } + function enclosingPowerOfTwo(value) { + // Return 2**N for integer N such that 2**N >= value. + return Math.floor(Math.pow(2, Math.ceil(Math.log(value) / Math.log(2.0)))); + } + function cosineWindow(windowLength, a, b) { + var even = 1 - windowLength % 2; + var newValues = new Float32Array(windowLength); + for (var i = 0; i < windowLength; ++i) { + var cosArg = (2.0 * Math.PI * i) / (windowLength + even - 1); + newValues[i] = a - b * Math.cos(cosArg); + } + return tensor1d(newValues, 'float32'); + } + var hannWindow = op({ hannWindow_: hannWindow_ }); + var hammingWindow = op({ hammingWindow_: hammingWindow_ }); + var frame = op({ frame_: frame_ }); + var stft = op({ stft_: stft_ }); + + var signal_ops = /*#__PURE__*/Object.freeze({ + hannWindow: hannWindow, + hammingWindow: hammingWindow, + frame: frame, + stft: stft + }); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Returns whether the targets are in the top K predictions. + * + * ```js + * const predictions = tf.tensor2d([[20, 10, 40, 30], [30, 50, -20, 10]]); + * const targets = tf.tensor1d([2, 0]); + * const precision = await tf.inTopKAsync(predictions, targets); + * precision.print(); + * ``` + * @param predictions 2-D or higher `tf.Tensor` with last dimension being + * at least `k`. + * @param targets 1-D or higher `tf.Tensor`. + * @param k Optional Number of top elements to look at for computing precision, + * default to 1. + */ + /** @doc {heading: 'Operations', subheading: 'Evaluation'} */ + function inTopKAsync_(predictions, targets, k) { + if (k === void 0) { k = 1; } + return __awaiter(this, void 0, void 0, function () { + var $predictions, $targets, lastDim, predictionsVals, targetsVals, _a, batch, size, precision, b, offset, vals, valAndInd, i, i; + return __generator(this, function (_b) { + switch (_b.label) { + case 0: + $predictions = convertToTensor(predictions, 'predictions', 'inTopK'); + $targets = convertToTensor(targets, 'targets', 'inTopK'); + assert($predictions.rank > 1, function () { return 'inTopK() expects the predictions to be of rank 2 or higher, ' + + ("but got " + $predictions.rank); }); + assert($predictions.rank - 1 === $targets.rank, function () { return "predictions rank should be 1 larger than " + + "targets rank, but got predictions rank " + + ($predictions.rank + " and targets rank " + $targets.rank); }); + assertShapesMatch($predictions.shape.slice(0, $predictions.shape.length - 1), $targets.shape, "predictions's shape should be align with the targets' shape, " + + 'except the last dimension.'); + lastDim = $predictions.shape[$predictions.shape.length - 1]; + assert(k > 0 && k <= lastDim, function () { return "'k' passed to inTopK() must be > 0 && <= the predictions last " + + ("dimension (" + lastDim + "), but got " + k); }); + return [4 /*yield*/, $predictions.data()]; + case 1: + predictionsVals = _b.sent(); + return [4 /*yield*/, $targets.data()]; + case 2: + targetsVals = _b.sent(); + _a = [predictionsVals.length / lastDim, lastDim], batch = _a[0], size = _a[1]; + precision = getTypedArrayFromDType('bool', batch); + for (b = 0; b < batch; b++) { + offset = b * size; + vals = predictionsVals.subarray(offset, offset + size); + valAndInd = []; + for (i = 0; i < vals.length; i++) { + valAndInd.push({ value: vals[i], index: i }); + } + valAndInd.sort(function (a, b) { return b.value - a.value; }); + precision[b] = 0; + for (i = 0; i < k; i++) { + if (valAndInd[i].index === targetsVals[b]) { + precision[b] = 1; + break; + } + } + } + if (predictions !== $predictions) { + $predictions.dispose(); + } + if (targets !== $targets) { + $targets.dispose(); + } + // Output precision has the same shape as targets. + return [2 /*return*/, tensor(precision, $targets.shape, 'bool')]; + } + }); + }); + } + var inTopKAsync = inTopKAsync_; + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + (function (Reduction) { + Reduction[Reduction["NONE"] = 0] = "NONE"; + Reduction[Reduction["MEAN"] = 1] = "MEAN"; + Reduction[Reduction["SUM"] = 2] = "SUM"; + Reduction[Reduction["SUM_BY_NONZERO_WEIGHTS"] = 3] = "SUM_BY_NONZERO_WEIGHTS"; + })(exports.Reduction || (exports.Reduction = {})); + /** + * Computes the weighted loss between two tensors. + * + * @param losses Tensor of shape `[batch_size, d1, ... dN]`. + * @param weights Tensor whose rank is either 0, or the same rank as + * `losses`, and must be broadcastable to `losses` (i.e., all + * dimensions must be either `1`, or the same as the corresponding + * `losses` dimension). + */ + /** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */ + function computeWeightedLoss_(losses, weights, reduction) { + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $losses = convertToTensor(losses, 'losses', 'computeWeightedLoss'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'computeWeightedLoss'); + } + var weightedLoss = ($weights == null) ? $losses : $losses.mul($weights); + if (reduction === exports.Reduction.NONE) { + return weightedLoss; + } + if (reduction === exports.Reduction.SUM) { + return weightedLoss.sum(); + } + if (reduction === exports.Reduction.MEAN) { + if ($weights == null) { + return weightedLoss.mean(); + } + else { + var broadcastFactor = $losses.size / $weights.size; + var result = weightedLoss.sum().div($weights.sum()); + return broadcastFactor > 1 ? result.div(scalar(broadcastFactor)) : + result; + } + } + if (reduction === exports.Reduction.SUM_BY_NONZERO_WEIGHTS) { + if ($weights == null) { + return weightedLoss.sum().div(scalar($losses.size)); + } + else { + var broadcastedWeights = $weights.mul(ones$1($losses.shape)); + var numNonZeros = broadcastedWeights.notEqual(scalar(0)).sum().toFloat(); + return weightedLoss.sum().div(numNonZeros); + } + } + throw Error("Unknown reduction: " + reduction); + } + /** + * Computes the absolute difference loss between two tensors. + * + * @param labels The ground truth output tensor, same dimensions as + * 'predictions'. + * @param predictions The predicted outputs. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + /** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */ + function absoluteDifference_(labels, predictions, weights, reduction) { + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $labels = convertToTensor(labels, 'labels', 'absoluteDifference'); + var $predictions = convertToTensor(predictions, 'predictions', 'absoluteDifference'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'absoluteDifference'); + } + assertShapesMatch($labels.shape, $predictions.shape, 'Error in absoluteDifference: '); + var losses = $labels.sub($predictions).abs(); + return computeWeightedLoss(losses, $weights, reduction); + } + /** + * Computes the mean squared error between two tensors. + * + * @param labels The ground truth output tensor, same dimensions as + * 'predictions'. + * @param predictions The predicted outputs. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + /** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */ + function meanSquaredError_(labels, predictions, weights, reduction) { + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $labels = convertToTensor(labels, 'labels', 'meanSquaredError'); + var $predictions = convertToTensor(predictions, 'predictions', 'meanSquaredError'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'meanSquaredError'); + } + assertShapesMatch($labels.shape, $predictions.shape, 'Error in meanSquaredError: '); + var losses = $labels.squaredDifference($predictions); + return computeWeightedLoss(losses, $weights, reduction); + } + /** + * Computes the cosine distance loss between two tensors. + * + * @param labels The ground truth output tensor, same dimensions as + * 'predictions'. + * @param predictions The predicted outputs. + * @param axis The dimension along which the cosine distance is computed. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + /** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */ + function cosineDistance_(labels, predictions, axis, weights, reduction) { + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $labels = convertToTensor(labels, 'labels', 'cosineDistance'); + var $predictions = convertToTensor(predictions, 'predictions', 'cosineDistance'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'cosineDistance'); + } + assertShapesMatch($labels.shape, $predictions.shape, 'Error in cosineDistance: '); + var one = scalar(1); + var losses = one.sub($labels.mul($predictions).sum(axis, true)); + return computeWeightedLoss(losses, $weights, reduction); + } + /** + * Computes the Hinge loss between two tensors. + * + * @param labels The ground truth output tensor, same dimensions as + * 'predictions'. + * @param predictions The predicted outputs. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + /** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */ + function hingeLoss_(labels, predictions, weights, reduction) { + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $labels = convertToTensor(labels, 'labels', 'hingeLoss'); + var $predictions = convertToTensor(predictions, 'predictions', 'hingeLoss'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'hingeLoss'); + } + assertShapesMatch($labels.shape, $predictions.shape, 'Error in hingeLoss: '); + var one = scalar(1); + // Convert binary labels to (-1, 1) + $labels = scalar(2).mul($labels).sub(one); + var losses = one.sub($labels.mul($predictions)).relu(); + return computeWeightedLoss(losses, $weights, reduction); + } + /** + * Computes the log loss between two tensors. + * + * @param labels The ground truth output tensor, same dimensions as + * 'predictions'. + * @param predictions The predicted outputs. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param epsilon A small increment to avoid taking log of zero + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + /** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */ + function logLoss_(labels, predictions, weights, epsilon, reduction) { + if (epsilon === void 0) { epsilon = 1e-7; } + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $labels = convertToTensor(labels, 'labels', 'logLoss'); + var $predictions = convertToTensor(predictions, 'predictions', 'logLoss'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'logLoss'); + } + assertShapesMatch($labels.shape, $predictions.shape, 'Error in logLoss: '); + var one = scalar(1); + var epsilonScalar = scalar(epsilon); + var losses = $labels.mul($predictions.add(epsilonScalar).log()) + .neg() + .sub(one.sub($labels).mul(one.sub($predictions).add(epsilonScalar).log())); + return computeWeightedLoss(losses, $weights, reduction); + } + function sigmoidCrossEntropyWithLogits_(labels, logits) { + var $labels = convertToTensor(labels, 'labels', 'sigmoidCrossEntropyWithLogits'); + var $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropyWithLogits'); + assertShapesMatch($labels.shape, $logits.shape, 'Error in sigmoidCrossEntropyWithLogits: '); + /** + * Implementation Details: + * + * For brevity, let `x = logits`, `z = labels`. The logistic loss is + * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)) + * = z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x))) + * = z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x))) + * = z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x)) + * = (1 - z) * x + log(1 + exp(-x)) + * = x - x * z + log(1 + exp(-x)) + * + * For x < 0, to avoid overflow in exp(-x), we reformulate the above + * x - x * z + log(1 + exp(-x)) + * = log(exp(x)) - x * z + log(1 + exp(-x)) + * = - x * z + log(1 + exp(x)) + * + * Hence, to ensure stability and avoid overflow, the implementation uses + * this equivalent formulation: + * max(x, 0) - x * z + log(1 + exp(-abs(x))) + */ + var maxOutput = $logits.relu(); + var outputXTarget = $logits.mul($labels); + var sigmoidOutput = $logits.abs().neg().exp().log1p(); + return maxOutput.sub(outputXTarget).add(sigmoidOutput); + } + /** + * Computes the sigmoid cross entropy loss between two tensors. + * + * If labelSmoothing is nonzero, smooth the labels towards 1/2: + * + * newMulticlassLabels = multiclassLabels * (1 - labelSmoothing) + * + 0.5 * labelSmoothing + * + * @param multiClassLabels The ground truth output tensor of shape + * [batch_size, num_classes], same dimensions as 'predictions'. + * @param logits The predicted outputs. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param labelSmoothing If greater than 0, then smooth the labels. + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + /** @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' } */ + function sigmoidCrossEntropy_(multiClassLabels, logits, weights, labelSmoothing, reduction) { + if (labelSmoothing === void 0) { labelSmoothing = 0; } + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $multiClassLabels = convertToTensor(multiClassLabels, 'multiClassLabels', 'sigmoidCrossEntropy'); + var $logits = convertToTensor(logits, 'logits', 'sigmoidCrossEntropy'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'sigmoidCrossEntropy'); + } + assertShapesMatch($multiClassLabels.shape, $logits.shape, 'Error in sigmoidCrossEntropy: '); + if (labelSmoothing > 0) { + var labelSmoothingScalar = scalar(labelSmoothing); + var one = scalar(1); + var half = scalar(0.5); + $multiClassLabels = $multiClassLabels.mul(one.sub(labelSmoothingScalar)) + .add(half.mul(labelSmoothingScalar)); + } + var losses = sigmoidCrossEntropyWithLogits_($multiClassLabels, $logits); + return computeWeightedLoss(losses, $weights, reduction); + } + /** + * Computes the huber loss between two tensors. + * + * @param labels The ground truth output tensor, same dimensions as + * 'predictions'. + * @param predictions The predicted outputs. + * @param weights Tensor whose rank is either 0, or the same rank as + * `labels`, and must be broadcastable to `labels` (i.e., all dimensions + * must be either `1`, or the same as the corresponding `losses` + * dimension). + * @param delta Point where huber loss changes from quadratic to linear. + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction`. + */ + /** @doc {heading: 'Training', subheading: 'Losses', namespace: 'losses'} */ + function huberLoss_(labels, predictions, weights, delta, reduction) { + if (delta === void 0) { delta = 1.0; } + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $labels = convertToTensor(labels, 'labels', 'huberLoss'); + var $predictions = convertToTensor(predictions, 'predictions', 'huberLoss'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'huberLoss'); + } + assertShapesMatch($labels.shape, $predictions.shape, 'Error in huberLoss: '); + var deltaScalar = scalar(delta); + var error = $predictions.sub($labels).abs(); + var quadratic = minimum(error, deltaScalar); + var linear = error.sub(quadratic); + var losses = scalar(0.5).mul(quadratic.square()).add(deltaScalar.mul(linear)); + return computeWeightedLoss(losses, $weights, reduction); + } + /** + * Computes softmax cross entropy between logits and labels. + * + * Measures the probability error in discrete classification tasks in which + * the classes are mutually exclusive (each entry is in exactly one class). + * For example, each CIFAR-10 image is labeled with one and only one label: an + * image can be a dog or a truck, but not both. + * + * `NOTE`: While the classes are mutually exclusive, their probabilities need + * not be. All that is required is that each row of labels is a valid + * probability distribution. If they are not, the computation of the gradient + * will be incorrect. + * + * `WARNING`: This op expects unscaled logits, since it performs a softmax on + * logits internally for efficiency. Do not call this op with the output of + * softmax, as it will produce incorrect results. + * + * logits and labels must have the same shape, e.g. [batch_size, num_classes] + * and the same dtype. + * @param labels The labels array. + * @param logits The logits array. + * @param dim The dimension softmax would be performed on. Defaults to `-1` + * which indicates the last dimension. + */ + function softmaxCrossEntropyWithLogits_(labels, logits, dim) { + if (dim === void 0) { dim = -1; } + if (dim === -1) { + dim = logits.rank - 1; + } + if (dim !== logits.rank - 1) { + throw Error("Softmax cross entropy along a non-last dimension is not yet " + + ("supported. Labels / logits was rank " + logits.rank + " ") + + ("and dim was " + dim)); + } + // Use a custom gradient for numerical stability. + var customOp = customGrad(function (labels, logits, save) { + // Reference: + // 1. http://cs231n.github.io/linear-classify/#softmax + // 2. https://blog.feedly.com/tricks-of-the-trade-logsumexp/ + var keepDims = true; + var lse = logits.logSumExp([dim], keepDims); + var logResult = logits.toFloat().sub(lse); + save([labels, logResult]); + var costVector = logResult.mul(labels).neg(); + var value = costVector.sum([dim]); + var gradFunc = function (dy, saved) { + var labels = saved[0], logResult = saved[1]; + var dyShape = expandShapeToKeepDim(dy.shape, [dim]); + return [ + dy.reshape(dyShape).mul(labels.toFloat().sub(logResult.exp())), + dy.reshape(dyShape).mul(logResult.exp().sub(labels.toFloat())), + ]; + }; + return { value: value, gradFunc: gradFunc }; + }); + return customOp(labels, logits); + } + /** + * Computes the softmax cross entropy loss between two tensors. + * + * If labelSmoothing is nonzero, smooth the labels towards 1/2: + * + * newOnehotLabels = onehotLabels * (1 - labelSmoothing) + * + labelSmoothing / numClasses + * + * @param onehotLabels One hot encoded labels + * [batch_size, num_classes], same dimensions as 'predictions'. + * @param logits The predicted outputs. + * @param weights Tensor whose rank is either 0, or 1, and must be + * broadcastable to `loss` of shape [batch_size] + * @param labelSmoothing If greater than 0, then smooth the labels. + * @param reduction Type of reduction to apply to loss. Should be of type + * `Reduction` + */ + /** @doc { heading: 'Training', subheading: 'Losses', namespace: 'losses' } */ + function softmaxCrossEntropy_(onehotLabels, logits, weights, labelSmoothing, reduction) { + if (labelSmoothing === void 0) { labelSmoothing = 0; } + if (reduction === void 0) { reduction = exports.Reduction.SUM_BY_NONZERO_WEIGHTS; } + var $onehotLabels = convertToTensor(onehotLabels, 'onehotLabels', 'softmaxCrossEntropy'); + var $logits = convertToTensor(logits, 'logits', 'softmaxCrossEntropy'); + var $weights = null; + if (weights != null) { + $weights = convertToTensor(weights, 'weights', 'softmaxCrossEntropy'); + } + assertShapesMatch($onehotLabels.shape, $logits.shape, 'Error in softmaxCrossEntropy: '); + if (labelSmoothing > 0) { + var labelSmoothingScalar = scalar(labelSmoothing); + var one = scalar(1); + var numClasses = scalar($onehotLabels.shape[1]); + $onehotLabels = $onehotLabels.mul(one.sub(labelSmoothingScalar)) + .add(labelSmoothingScalar.div(numClasses)); + } + var losses = softmaxCrossEntropyWithLogits_($onehotLabels, $logits); + return computeWeightedLoss(losses, $weights, reduction); + } + var absoluteDifference = op({ absoluteDifference_: absoluteDifference_ }); + var computeWeightedLoss = op({ computeWeightedLoss_: computeWeightedLoss_ }); + var cosineDistance = op({ cosineDistance_: cosineDistance_ }); + var hingeLoss = op({ hingeLoss_: hingeLoss_ }); + var huberLoss = op({ huberLoss_: huberLoss_ }); + var logLoss = op({ logLoss_: logLoss_ }); + var meanSquaredError = op({ meanSquaredError_: meanSquaredError_ }); + var sigmoidCrossEntropy = op({ sigmoidCrossEntropy_: sigmoidCrossEntropy_ }); + var softmaxCrossEntropy = op({ softmaxCrossEntropy_: softmaxCrossEntropy_ }); + + var loss_ops = /*#__PURE__*/Object.freeze({ + get Reduction () { return exports.Reduction; }, + absoluteDifference: absoluteDifference, + computeWeightedLoss: computeWeightedLoss, + cosineDistance: cosineDistance, + hingeLoss: hingeLoss, + huberLoss: huberLoss, + logLoss: logLoss, + meanSquaredError: meanSquaredError, + sigmoidCrossEntropy: sigmoidCrossEntropy, + softmaxCrossEntropy: softmaxCrossEntropy + }); + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Gram-Schmidt orthogonalization. + * + * ```js + * const x = tf.tensor2d([[1, 2], [3, 4]]); + * let y = tf.linalg.gramSchmidt(x); + * y.print(); + * console.log('Othogonalized:'); + * y.dot(y.transpose()).print(); // should be nearly the identity matrix. + * console.log('First row direction maintained:'); + * const data = await y.array(); + * console.log(data[0][1] / data[0][0]); // should be nearly 2. + * ``` + * + * @param xs The vectors to be orthogonalized, in one of the two following + * formats: + * - An Array of `tf.Tensor1D`. + * - A `tf.Tensor2D`, i.e., a matrix, in which case the vectors are the rows + * of `xs`. + * In each case, all the vectors must have the same length and the length + * must be greater than or equal to the number of vectors. + * @returns The orthogonalized and normalized vectors or matrix. + * Orthogonalization means that the vectors or the rows of the matrix + * are orthogonal (zero inner products). Normalization means that each + * vector or each row of the matrix has an L2 norm that equals `1`. + */ + /** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ + function gramSchmidt_(xs) { + var inputIsTensor2D; + if (Array.isArray(xs)) { + inputIsTensor2D = false; + assert(xs != null && xs.length > 0, function () { return 'Gram-Schmidt process: input must not be null, undefined, or ' + + 'empty'; }); + var dim_1 = xs[0].shape[0]; + var _loop_1 = function (i) { + assert(xs[i].shape[0] === dim_1, function () { + return 'Gram-Schmidt: Non-unique lengths found in the input vectors: ' + + ("(" + xs[i].shape[0] + " vs. " + dim_1 + ")"); + }); + }; + for (var i = 1; i < xs.length; ++i) { + _loop_1(i); + } + } + else { + inputIsTensor2D = true; + xs = split(xs, xs.shape[0], 0).map(function (x) { return squeeze(x, [0]); }); + } + assert(xs.length <= xs[0].shape[0], function () { return "Gram-Schmidt: Number of vectors (" + xs.length + ") exceeds " + + ("number of dimensions (" + xs[0].shape[0] + ")."); }); + var ys = []; + var xs1d = xs; + var _loop_2 = function (i) { + ys.push(ENGINE.tidy(function () { + var x = xs1d[i]; + if (i > 0) { + for (var j = 0; j < i; ++j) { + var proj = sum$1(ys[j].mulStrict(x)).mul(ys[j]); + x = x.sub(proj); + } + } + return x.div(norm(x, 'euclidean')); + })); + }; + for (var i = 0; i < xs.length; ++i) { + _loop_2(i); + } + if (inputIsTensor2D) { + return stack(ys, 0); + } + else { + return ys; + } + } + /** + * Compute QR decomposition of m-by-n matrix using Householder transformation. + * + * Implementation based on + * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf] + * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf) + * + * ```js + * const a = tf.tensor2d([[1, 2], [3, 4]]); + * let [q, r] = tf.linalg.qr(a); + * console.log('Q'); + * q.print(); + * console.log('R'); + * r.print(); + * console.log('Orthogonalized'); + * q.dot(q.transpose()).print() // should be nearly the identity matrix. + * console.log('Reconstructed'); + * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]]; + * ``` + * + * @param x The `tf.Tensor` to be QR-decomposed. Must have rank >= 2. Suppose + * it has the shape `[..., M, N]`. + * @param fullMatrices An optional boolean parameter. Defaults to `false`. + * If `true`, compute full-sized `Q`. If `false` (the default), + * compute only the leading N columns of `Q` and `R`. + * @returns An `Array` of two `tf.Tensor`s: `[Q, R]`. `Q` is a unitary matrix, + * i.e., its columns all have unit norm and are mutually orthogonal. + * If `M >= N`, + * If `fullMatrices` is `false` (default), + * - `Q` has a shape of `[..., M, N]`, + * - `R` has a shape of `[..., N, N]`. + * If `fullMatrices` is `true` (default), + * - `Q` has a shape of `[..., M, M]`, + * - `R` has a shape of `[..., M, N]`. + * If `M < N`, + * - `Q` has a shape of `[..., M, M]`, + * - `R` has a shape of `[..., M, N]`. + * @throws If the rank of `x` is less than 2. + */ + /** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ + function qr_(x, fullMatrices) { + if (fullMatrices === void 0) { fullMatrices = false; } + if (x.rank < 2) { + throw new Error("qr() requires input tensor to have a rank >= 2, but got rank " + x.rank); + } + else if (x.rank === 2) { + return qr2d(x, fullMatrices); + } + else { + // Rank > 2. + // TODO(cais): Below we split the input into individual 2D tensors, + // perform QR decomposition on them and then stack the results back + // together. We should explore whether this can be parallelized. + var outerDimsProd = x.shape.slice(0, x.shape.length - 2) + .reduce(function (value, prev) { return value * prev; }); + var x2ds = unstack(x.reshape([ + outerDimsProd, x.shape[x.shape.length - 2], + x.shape[x.shape.length - 1] + ]), 0); + var q2ds_1 = []; + var r2ds_1 = []; + x2ds.forEach(function (x2d) { + var _a = qr2d(x2d, fullMatrices), q2d = _a[0], r2d = _a[1]; + q2ds_1.push(q2d); + r2ds_1.push(r2d); + }); + var q = stack(q2ds_1, 0).reshape(x.shape); + var r = stack(r2ds_1, 0).reshape(x.shape); + return [q, r]; + } + } + function qr2d(x, fullMatrices) { + if (fullMatrices === void 0) { fullMatrices = false; } + return ENGINE.tidy(function () { + if (x.shape.length !== 2) { + throw new Error("qr2d() requires a 2D Tensor, but got a " + x.shape.length + "D Tensor."); + } + var m = x.shape[0]; + var n = x.shape[1]; + var q = eye(m); // Orthogonal transform so far. + var r = x.clone(); // Transformed matrix so far. + var one2D = tensor2d([[1]], [1, 1]); + var w = one2D.clone(); + var iters = m >= n ? n : m; + var _loop_3 = function (j) { + var _a; + // This tidy within the for-loop ensures we clean up temporary + // tensors as soon as they are no longer needed. + var rTemp = r; + var wTemp = w; + var qTemp = q; + _a = ENGINE.tidy(function () { + // Find H = I - tau * w * w', to put zeros below R(j, j). + var rjEnd1 = r.slice([j, j], [m - j, 1]); + var normX = rjEnd1.norm(); + var rjj = r.slice([j, j], [1, 1]); + // The sign() function returns 0 on 0, which causes division by zero. + var s = tensor2d([[-1]]).where(rjj.greater(0), tensor2d([[1]])); + var u1 = rjj.sub(s.mul(normX)); + var wPre = rjEnd1.div(u1); + if (wPre.shape[0] === 1) { + w = one2D.clone(); + } + else { + w = one2D.concat(wPre.slice([1, 0], [wPre.shape[0] - 1, wPre.shape[1]]), 0); + } + var tau = s.matMul(u1).div(normX).neg(); + // -- R := HR, Q := QH. + var rjEndAll = r.slice([j, 0], [m - j, n]); + var tauTimesW = tau.mul(w); + if (j === 0) { + r = rjEndAll.sub(tauTimesW.matMul(w.transpose().matMul(rjEndAll))); + } + else { + var rTimesTau = rjEndAll.sub(tauTimesW.matMul(w.transpose().matMul(rjEndAll))); + r = r.slice([0, 0], [j, n]).concat(rTimesTau, 0); + } + var qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]); + if (j === 0) { + q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tauTimesW.transpose())); + } + else { + var qTimesTau = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tauTimesW.transpose())); + q = q.slice([0, 0], [m, j]).concat(qTimesTau, 1); + } + return [w, r, q]; + }), w = _a[0], r = _a[1], q = _a[2]; + dispose([rTemp, wTemp, qTemp]); + }; + for (var j = 0; j < iters; ++j) { + _loop_3(j); + } + if (!fullMatrices && m > n) { + q = q.slice([0, 0], [m, n]); + r = r.slice([0, 0], [n, n]); + } + return [q, r]; + }); + } + var gramSchmidt = op({ gramSchmidt_: gramSchmidt_ }); + var qr = op({ qr_: qr_ }); + + var linalg_ops = /*#__PURE__*/Object.freeze({ + gramSchmidt: gramSchmidt, + qr: qr + }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Bilinear resize a batch of 3D images to a new shape. + * + * @param images The images, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param size The new shape `[newHeight, newWidth]` to resize the + * images to. Each channel is resized individually. + * @param alignCorners Defaults to False. If true, rescale + * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4 + * corners of images and resized images. If false, rescale by + * `new_height / height`. Treat similarly the width dimension. + */ + /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ + function resizeBilinear_(images, size, alignCorners) { + if (alignCorners === void 0) { alignCorners = false; } + var $images = convertToTensor(images, 'images', 'resizeBilinear'); + assert($images.rank === 3 || $images.rank === 4, function () { return "Error in resizeBilinear: x must be rank 3 or 4, but got " + + ("rank " + $images.rank + "."); }); + assert(size.length === 2, function () { return "Error in resizeBilinear: new shape must 2D, but got shape " + + (size + "."); }); + var batchImages = $images; + var reshapedTo4D = false; + if ($images.rank === 3) { + reshapedTo4D = true; + batchImages = + $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]); + } + var newHeight = size[0], newWidth = size[1]; + var forward = function (backend, save) { + save([batchImages]); + return backend.resizeBilinear(batchImages, newHeight, newWidth, alignCorners); + }; + var backward = function (dy, saved) { + return { + batchImages: function () { return ENGINE.runKernelFunc(function (backend) { return backend.resizeBilinearBackprop(dy, saved[0], alignCorners); }, {}); } + }; + }; + var res = ENGINE.runKernelFunc(forward, { batchImages: batchImages }, backward); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * NearestNeighbor resize a batch of 3D images to a new shape. + * + * @param images The images, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param size The new shape `[newHeight, newWidth]` to resize the + * images to. Each channel is resized individually. + * @param alignCorners Defaults to False. If true, rescale + * input by `(new_height - 1) / (height - 1)`, which exactly aligns the 4 + * corners of images and resized images. If false, rescale by + * `new_height / height`. Treat similarly the width dimension. + */ + /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ + function resizeNearestNeighbor_(images, size, alignCorners) { + if (alignCorners === void 0) { alignCorners = false; } + var $images = convertToTensor(images, 'images', 'resizeNearestNeighbor'); + assert($images.rank === 3 || $images.rank === 4, function () { return "Error in resizeNearestNeighbor: x must be rank 3 or 4, but got " + + ("rank " + $images.rank + "."); }); + assert(size.length === 2, function () { + return "Error in resizeNearestNeighbor: new shape must 2D, but got shape " + + (size + "."); + }); + assert($images.dtype === 'float32' || $images.dtype === 'int32', function () { return '`images` must have `int32` or `float32` as dtype'; }); + var batchImages = $images; + var reshapedTo4D = false; + if ($images.rank === 3) { + reshapedTo4D = true; + batchImages = + $images.as4D(1, $images.shape[0], $images.shape[1], $images.shape[2]); + } + var newHeight = size[0], newWidth = size[1]; + var forward = function (backend, save) { + save([batchImages]); + return backend.resizeNearestNeighbor(batchImages, newHeight, newWidth, alignCorners); + }; + var backward = function (dy, saved) { + return { + batchImages: function () { return ENGINE.runKernelFunc(function (backend) { return backend.resizeNearestNeighborBackprop(dy, saved[0], alignCorners); }, {}); } + }; + }; + var res = ENGINE.runKernelFunc(forward, { batchImages: batchImages }, backward); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * Performs non maximum suppression of bounding boxes based on + * iou (intersection over union) + * + * @param boxes a 2d tensor of shape `[numBoxes, 4]`. Each entry is + * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the corners of + * the bounding box. + * @param scores a 1d tensor providing the box scores of shape `[numBoxes]`. + * @param maxOutputSize The maximum number of boxes to be selected. + * @param iouThreshold A float representing the threshold for deciding whether + * boxes overlap too much with respect to IOU. Must be between [0, 1]. + * Defaults to 0.5 (50% box overlap). + * @param scoreThreshold A threshold for deciding when to remove boxes based + * on score. Defaults to -inf, which means any score is accepted. + * @return A 1D tensor with the selected box indices. + */ + /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ + function nonMaxSuppression_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + if (iouThreshold === void 0) { iouThreshold = 0.5; } + if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; } + var $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppression'); + var $scores = convertToTensor(scores, 'scores', 'nonMaxSuppression'); + var inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); + maxOutputSize = inputs.maxOutputSize; + iouThreshold = inputs.iouThreshold; + scoreThreshold = inputs.scoreThreshold; + return ENGINE.runKernelFunc(function (b) { return b.nonMaxSuppression($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); }, { $boxes: $boxes }); + } + /** This is the async version of `nonMaxSuppression` */ + function nonMaxSuppressionAsync_(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + if (iouThreshold === void 0) { iouThreshold = 0.5; } + if (scoreThreshold === void 0) { scoreThreshold = Number.NEGATIVE_INFINITY; } + return __awaiter(this, void 0, void 0, function () { + var $boxes, $scores, inputs, boxesAndScores, boxesVals, scoresVals, res; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + $boxes = convertToTensor(boxes, 'boxes', 'nonMaxSuppressionAsync'); + $scores = convertToTensor(scores, 'scores', 'nonMaxSuppressionAsync'); + inputs = nonMaxSuppSanityCheck($boxes, $scores, maxOutputSize, iouThreshold, scoreThreshold); + maxOutputSize = inputs.maxOutputSize; + iouThreshold = inputs.iouThreshold; + scoreThreshold = inputs.scoreThreshold; + return [4 /*yield*/, Promise.all([$boxes.data(), $scores.data()])]; + case 1: + boxesAndScores = _a.sent(); + boxesVals = boxesAndScores[0]; + scoresVals = boxesAndScores[1]; + res = nonMaxSuppressionImpl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + if ($boxes !== boxes) { + $boxes.dispose(); + } + if ($scores !== scores) { + $scores.dispose(); + } + return [2 /*return*/, res]; + } + }); + }); + } + function nonMaxSuppSanityCheck(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + if (iouThreshold == null) { + iouThreshold = 0.5; + } + if (scoreThreshold == null) { + scoreThreshold = Number.NEGATIVE_INFINITY; + } + var numBoxes = boxes.shape[0]; + maxOutputSize = Math.min(maxOutputSize, numBoxes); + assert(0 <= iouThreshold && iouThreshold <= 1, function () { return "iouThreshold must be in [0, 1], but was '" + iouThreshold + "'"; }); + assert(boxes.rank === 2, function () { return "boxes must be a 2D tensor, but was of rank '" + boxes.rank + "'"; }); + assert(boxes.shape[1] === 4, function () { + return "boxes must have 4 columns, but 2nd dimension was " + boxes.shape[1]; + }); + assert(scores.rank === 1, function () { return 'scores must be a 1D tensor'; }); + assert(scores.shape[0] === numBoxes, function () { return "scores has incompatible shape with boxes. Expected " + numBoxes + ", " + + ("but was " + scores.shape[0]); }); + return { maxOutputSize: maxOutputSize, iouThreshold: iouThreshold, scoreThreshold: scoreThreshold }; + } + /** + * Extracts crops from the input image tensor and resizes them using bilinear + * sampling or nearest neighbor sampling (possibly with aspect ratio change) + * to a common output size specified by crop_size. + * + * @param image 4d tensor of shape `[batch,imageHeight,imageWidth, depth]`, + * where imageHeight and imageWidth must be positive, specifying the + * batch of images from which to take crops + * @param boxes 2d float32 tensor of shape `[numBoxes, 4]`. Each entry is + * `[y1, x1, y2, x2]`, where `(y1, x1)` and `(y2, x2)` are the normalized + * coordinates of the box in the boxInd[i]'th image in the batch + * @param boxInd 1d int32 tensor of shape `[numBoxes]` with values in range + * `[0, batch)` that specifies the image that the `i`-th box refers to. + * @param cropSize 1d int32 tensor of 2 elements `[cropHeigh, cropWidth]` + * specifying the size to which all crops are resized to. + * @param method Optional string from `'bilinear' | 'nearest'`, + * defaults to bilinear, which specifies the sampling method for resizing + * @param extrapolationValue A threshold for deciding when to remove boxes based + * on score. Defaults to 0. + * @return A 4D tensor of the shape `[numBoxes,cropHeight,cropWidth,depth]` + */ + /** @doc {heading: 'Operations', subheading: 'Images', namespace: 'image'} */ + function cropAndResize_(image, boxes, boxInd, cropSize, method, extrapolationValue) { + var $image = convertToTensor(image, 'image', 'cropAndResize', 'float32'); + var $boxes = convertToTensor(boxes, 'boxes', 'cropAndResize', 'float32'); + var $boxInd = convertToTensor(boxInd, 'boxInd', 'cropAndResize', 'int32'); + method = method || 'bilinear'; + extrapolationValue = extrapolationValue || 0; + var numBoxes = $boxes.shape[0]; + assert($image.rank === 4, function () { return 'Error in cropAndResize: image must be rank 4,' + + ("but got rank " + $image.rank + "."); }); + assert($boxes.rank === 2 && $boxes.shape[1] === 4, function () { return "Error in cropAndResize: boxes must be have size [" + numBoxes + ",4] " + + ("but had shape " + $boxes.shape + "."); }); + assert($boxInd.rank === 1 && $boxInd.shape[0] === numBoxes, function () { return "Error in cropAndResize: boxInd must be have size [" + numBoxes + "] " + + ("but had shape " + $boxes.shape + "."); }); + assert(cropSize.length === 2, function () { return "Error in cropAndResize: cropSize must be of length 2, but got " + + ("length " + cropSize.length + "."); }); + assert(cropSize[0] >= 1 && cropSize[1] >= 1, function () { return "cropSize must be atleast [1,1], but was " + cropSize; }); + assert(method === 'bilinear' || method === 'nearest', function () { return "method must be bilinear or nearest, but was " + method; }); + var forward = function (backend, save) { + return backend.cropAndResize($image, $boxes, $boxInd, cropSize, method, extrapolationValue); + }; + var res = ENGINE.runKernelFunc(forward, { images: $image, boxes: $boxes, boxInd: $boxInd }, null /* der */, 'CropAndResize', { method: method, extrapolationValue: extrapolationValue, cropSize: cropSize }); + return res; + } + var resizeBilinear = op({ resizeBilinear_: resizeBilinear_ }); + var resizeNearestNeighbor = op({ resizeNearestNeighbor_: resizeNearestNeighbor_ }); + var nonMaxSuppression = op({ nonMaxSuppression_: nonMaxSuppression_ }); + var nonMaxSuppressionAsync = nonMaxSuppressionAsync_; + var cropAndResize = op({ cropAndResize_: cropAndResize_ }); + + var image_ops = /*#__PURE__*/Object.freeze({ + resizeBilinear: resizeBilinear, + resizeNearestNeighbor: resizeNearestNeighbor, + nonMaxSuppression: nonMaxSuppression, + nonMaxSuppressionAsync: nonMaxSuppressionAsync, + cropAndResize: cropAndResize + }); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Whether we should call fused ops. + var shouldFuse = function (gradientDepth, activation) { + var gradientMode = gradientDepth > 0; + return !gradientMode && (activation === 'linear' || activation === 'relu'); + }; + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // Returns gradient for fused activation. + var getFusedDyActivation = function (dy, y, activation) { + if (activation == null || activation === 'linear') { + return dy; + } + if (activation === 'relu') { + return dy.mul(y.step()); + } + throw new Error("Gradient for activation " + activation + " has not been " + + "implemented yet."); + }; + // Returns gradient for fused bias. + var getFusedBiasGradient = function (bias, dyActivation) { + var res = dyActivation; + var reduceAxes = getReductionAxes(bias.shape, dyActivation.shape); + if (reduceAxes.length > 0) { + res = res.sum(reduceAxes); + } + return res.reshape(bias.shape); + }; + var applyActivation = function (x, activation, preluActivationWeights) { + if (activation === 'linear') { + return x; + } + else if (activation === 'relu') { + return relu(x); + } + else if (activation === 'elu') { + return elu(x); + } + else if (activation === 'relu6') { + return relu6(x); + } + else if (activation === 'prelu') { + return prelu(x, preluActivationWeights); + } + throw new Error("Unknown fused activation " + activation + "."); + }; + /** + * Computes the dot product of two matrices with optional activation and bias. + * + * ```js + * const a = tf.tensor2d([-1, -2], [1, 2]); + * const b = tf.tensor2d([1, 2, 3, 4], [2, 2]); + * const bias = tf.tensor2d([1, 2], [1, 2]); + * + * tf.fused.matMul({a, b, bias, activation: 'relu'}).print(); + * ``` + * + * @param obj An object with the following properties: + * - `a` First matrix in dot product operation. + * - `b` Second matrix in dot product operation. + * - `transposeA` If true, `a` is transposed before multiplication. + * - `transposeB` If true, `b` is transposed before multiplication. + * - `bias` Matrix to be added to the result. + * - `activation` Name of activation kernel (defaults to `linear`). + * - `preluActivationWeights` Tensor of prelu weights. + */ + function matMul_$1(_a) { + var _b; + var a = _a.a, b = _a.b, _c = _a.transposeA, transposeA = _c === void 0 ? false : _c, _d = _a.transposeB, transposeB = _d === void 0 ? false : _d, bias = _a.bias, _e = _a.activation, activation = _e === void 0 ? 'linear' : _e, preluActivationWeights = _a.preluActivationWeights; + if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { + var result = matMul(a, b, transposeA, transposeB); + if (bias != null) { + result = add(result, bias); + } + return applyActivation(result, activation, preluActivationWeights); + } + var $a = convertToTensor(a, 'a', 'fused matMul'); + var $b = convertToTensor(b, 'b', 'fused matMul'); + _b = makeTypesMatch($a, $b), $a = _b[0], $b = _b[1]; + var innerShapeA = transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1]; + var innerShapeB = transposeB ? $b.shape[$b.rank - 1] : $b.shape[$b.rank - 2]; + var outerShapeA = transposeA ? $a.shape[$a.rank - 1] : $a.shape[$a.rank - 2]; + var outerShapeB = transposeB ? $b.shape[$b.rank - 2] : $b.shape[$b.rank - 1]; + var outerDimsA = $a.shape.slice(0, -2); + var outerDimsB = $b.shape.slice(0, -2); + var batchDimA = sizeFromShape(outerDimsA); + var batchDimB = sizeFromShape(outerDimsB); + assert($a.rank >= 2 && $b.rank >= 2 && $a.rank === $b.rank, function () { + return "Error in fused matMul: inputs must have the same rank of at least " + + ("2, got ranks " + $a.rank + " and " + $b.rank + "."); + }); + assert(arraysEqual(outerDimsA, outerDimsB), function () { return "Error in fused matMul: outer dimensions (" + outerDimsA + ") and (" + + (outerDimsB + ") of Tensors with shapes " + $a.shape + " and ") + + ($b.shape + " must match."); }); + assert(innerShapeA === innerShapeB, function () { return "Error in fused matMul: inner shapes (" + innerShapeA + ") and (" + + (innerShapeB + ") of Tensors with shapes " + $a.shape + " and ") + + ($b.shape + " and transposeA=" + transposeA) + + (" and transposeB=" + transposeB + " must match."); }); + var outShape = $a.shape.slice(0, -2).concat([outerShapeA, outerShapeB]); + var a3D = transposeA ? $a.as3D(batchDimA, innerShapeA, outerShapeA) : + $a.as3D(batchDimA, outerShapeA, innerShapeA); + var b3D = transposeB ? $b.as3D(batchDimB, outerShapeB, innerShapeB) : + $b.as3D(batchDimB, innerShapeB, outerShapeB); + var $bias; + if (bias != null) { + $bias = convertToTensor(bias, 'bias', 'fused matMul'); + $bias = makeTypesMatch($bias, $a)[0]; + assertAndGetBroadcastShape(outShape, $bias.shape); + } + var $preluActivationWeights; + if (preluActivationWeights != null) { + $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused matMul'); + } + var grad = function (dy, saved) { + var a3D = saved[0], b3D = saved[1], y = saved[2]; + var dyActivation = getFusedDyActivation(dy, y, activation); + var biasGradient = {}; + if (bias != null) { + biasGradient = { $bias: function () { return getFusedBiasGradient($bias, dyActivation); } }; + } + if (!transposeA && !transposeB) { + return Object.assign({ + $a: function () { return dyActivation.matMul(b3D, false, true); }, + $b: function () { return a3D.matMul(dyActivation, true, false); } + }, biasGradient); + } + else if (!transposeA && transposeB) { + return Object.assign({ + $a: function () { return dyActivation.matMul(b3D, false, false); }, + $b: function () { return dyActivation.matMul(a3D, true, false); } + }, biasGradient); + } + else if (transposeA && !transposeB) { + return Object.assign({ + $a: function () { return b3D.matMul(dyActivation, false, true); }, + $b: function () { return a3D.matMul(dyActivation, false, false); } + }, biasGradient); + } + else { + return Object.assign({ + $a: function () { return b3D.matMul(dyActivation, true, true); }, + $b: function () { return dyActivation.matMul(a3D, true, true); } + }, biasGradient); + } + }; + var inputs = { $a: a3D, $b: b3D }; + if (bias != null) { + inputs.$bias = $bias; + } + if (preluActivationWeights != null) { + inputs.$preluActivationWeights = $preluActivationWeights; + } + var res = ENGINE.runKernelFunc(function (backend, save) { + var y = backend.fusedBatchMatMul({ + a: a3D, + b: b3D, + transposeA: transposeA, + transposeB: transposeB, + bias: $bias, + activation: activation, + preluActivationWeights: $preluActivationWeights + }); + save([a3D, b3D, y]); + return y; + }, inputs, grad); + return res.reshape(outShape); + } + /** + * Computes a 2D convolution over the input x, optionally fused with adding a + * bias and applying an activation. + * + * ```js + * const inputDepth = 2; + * const inShape = [2, 2, 2, inputDepth]; + * const outputDepth = 2; + * const fSize = 1; + * const pad = 0; + * const strides = 1; + * + * const x = tf.tensor4d( [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + * 16], inShape); + * const w = tf.tensor4d([-1, 1, -2, 0.5], [fSize, fSize, inputDepth, + * outputDepth]); + * + * tf.fused.conv2d({ x, filter: w, strides, pad, dataFormat: 'NHWC', + * dilations: [1, 1], bias: tf.scalar(5), activation: 'relu' }).print(); + * ``` + * + * @param obj An object with the following properties: + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter, rank 4, of shape + * `[filterHeight, filterWidth, inDepth, outDepth]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid` output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. Only "NHWC" is currently supported. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + * @param bias Tensor to be added to the result. + * @param activation Name of activation kernel (defaults to `linear`) to be + * applied + * after biasAdd. + * @param preluActivationWeights Tensor of prelu weights to be applied as part + * of a `prelu` activation, typically the same shape as `x`. + */ + function conv2d_$1(_a) { + var x = _a.x, filter = _a.filter, strides = _a.strides, pad = _a.pad, _b = _a.dataFormat, dataFormat = _b === void 0 ? 'NHWC' : _b, _c = _a.dilations, dilations = _c === void 0 ? [1, 1] : _c, dimRoundingMode = _a.dimRoundingMode, bias = _a.bias, _d = _a.activation, activation = _d === void 0 ? 'linear' : _d, preluActivationWeights = _a.preluActivationWeights; + activation = activation || 'linear'; + if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { + var result = conv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode); + if (bias != null) { + result = add(result, bias); + } + return applyActivation(result, activation, preluActivationWeights); + } + var $x = convertToTensor(x, 'x', 'conv2d'); + var $filter = convertToTensor(filter, 'filter', 'conv2d'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + assert(x4D.rank === 4, function () { return "Error in fused conv2d: input must be rank 4, but got rank " + + (x4D.rank + "."); }); + assert($filter.rank === 4, function () { return "Error in fused conv2d: filter must be rank 4, but got rank " + + ($filter.rank + "."); }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in fused conv2d: pad must be an integer when using, " + + ("dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + assert(x4D.shape[3] === $filter.shape[2], function () { return "Error in conv2d: depth of input (" + x4D.shape[3] + ") must match " + + ("input depth for filter " + $filter.shape[2] + "."); }); + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { return 'Error in conv2D: Either strides or dilations must be 1. ' + + ("Got strides " + strides + " and dilations '" + dilations + "'"); }); + assert(dataFormat === 'NHWC', function () { return "Error in conv2d: got dataFormat of " + dataFormat + " but only NHWC is currently supported."; }); + var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode); + var $bias; + if (bias != null) { + $bias = convertToTensor(bias, 'bias', 'fused conv2d'); + $bias = makeTypesMatch($bias, $x)[0]; + assertAndGetBroadcastShape(convInfo.outShape, $bias.shape); + } + var $preluActivationWeights; + if (preluActivationWeights != null) { + $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused conv2d'); + } + var grad = function (dy, saved) { + var _a = saved, $filter = _a[0], x4D = _a[1], y = _a[2]; + var dyActivation = getFusedDyActivation(dy, y, activation); + assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of fused conv2D: ' + + "dilation rates greater than 1 " + + ("are not yet supported in gradients. Got dilations '" + dilations + "'"); }); + var biasGradient = {}; + if (bias != null) { + biasGradient = { $bias: function () { return getFusedBiasGradient($bias, dyActivation); } }; + } + return Object.assign({ + x: function () { + return conv2dDerInput(x4D.shape, dyActivation, $filter, strides, pad); + }, + filter: function () { + return conv2dDerFilter(x4D, dyActivation, $filter.shape, strides, pad); + } + }, biasGradient); + }; + var inputs = { x: x4D, filter: $filter }; + if (bias != null) { + inputs.bias = $bias; + } + if (preluActivationWeights != null) { + inputs.preluActivationWeights = $preluActivationWeights; + } + var inputsToSave = [$filter, x4D]; + var outputsToSave = [true]; // Save the only output. + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.fusedConv2d({ + input: x4D, + filter: $filter, + convInfo: convInfo, + bias: $bias, + activation: activation, + preluActivationWeights: $preluActivationWeights + }); + save([$filter, x4D, res]); + return res; + }, inputs, grad, 'FusedConv2D', { convInfo: convInfo, activation: activation }, inputsToSave, outputsToSave); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + /** + * Computes depthwise 2D convolution, optionally fused with adding a + * bias and applying an activation. + * + * Given a 4D `input` array and a `filter` array of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]` containing + * `inChannels` convolutional filters of depth 1, this op applies a + * different filter to each input channel (expanding from 1 channel to + * `channelMultiplier` channels for each), then concatenates the results + * together. The output has `inChannels * channelMultiplier` channels. + * + * See + * [https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d]( + * https://www.tensorflow.org/api_docs/python/tf/nn/depthwise_conv2d) + * for more details. + * + * @param obj An object with the following properties: + * @param x The input tensor, of rank 4 or rank 3, of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is + * assumed. + * @param filter The filter tensor, rank 4, of shape + * `[filterHeight, filterWidth, inChannels, channelMultiplier]`. + * @param strides The strides of the convolution: `[strideHeight, + * strideWidth]`. If strides is a single number, then `strideHeight == + * strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single + * number, then `dilationHeight == dilationWidth`. If it is greater than + * 1, then all values of `strides` must be 1. + * @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to + * "NHWC". Specify the data format of the input and output data. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. Only "NHWC" is currently supported. + * @param dimRoundingMode The rounding mode used when computing output + * dimensions if pad is a number. If none is provided, it will not round + * and error if the output is of fractional size. + * @param bias Tensor to be added to the result. + * @param activation Name of activation kernel (defaults to `linear`). + * @param preluActivationWeights Tensor of prelu weights to be applied as part + * of a `prelu` activation, typically the same shape as `x`. + */ + function depthwiseConv2d_$1(_a) { + var x = _a.x, filter = _a.filter, strides = _a.strides, pad = _a.pad, _b = _a.dataFormat, dataFormat = _b === void 0 ? 'NHWC' : _b, _c = _a.dilations, dilations = _c === void 0 ? [1, 1] : _c, dimRoundingMode = _a.dimRoundingMode, bias = _a.bias, _d = _a.activation, activation = _d === void 0 ? 'linear' : _d, preluActivationWeights = _a.preluActivationWeights; + if (shouldFuse(ENGINE.state.gradientDepth, activation) === false) { + var result = depthwiseConv2d(x, filter, strides, pad, dataFormat, dilations, dimRoundingMode); + if (bias != null) { + result = add(result, bias); + } + return applyActivation(result, activation, preluActivationWeights); + } + var $x = convertToTensor(x, 'x', 'depthwiseConv2d'); + var $filter = convertToTensor(filter, 'filter', 'depthwiseConv2d'); + var x4D = $x; + var reshapedTo4D = false; + if ($x.rank === 3) { + reshapedTo4D = true; + x4D = $x.as4D(1, $x.shape[0], $x.shape[1], $x.shape[2]); + } + assert(x4D.rank === 4, function () { return "Error in fused depthwiseConv2d: input must be rank 4, but got " + + ("rank " + x4D.rank + "."); }); + assert($filter.rank === 4, function () { return "Error in fused depthwiseConv2d: filter must be rank 4, " + + ("but got rank " + $filter.rank + "."); }); + assert(x4D.shape[3] === $filter.shape[2], function () { return "Error in fused depthwiseConv2d: number of input channels " + + ("(" + x4D.shape[3] + ") must match the inChannels dimension in ") + + ("filter " + $filter.shape[2] + "."); }); + if (dilations == null) { + dilations = [1, 1]; + } + assert(eitherStridesOrDilationsAreOne(strides, dilations), function () { + return 'Error in fused depthwiseConv2d: Either strides or dilations must ' + + ("be 1. Got strides " + strides + " and dilations '" + dilations + "'"); + }); + if (dimRoundingMode != null) { + assert(isInt(pad), function () { return "Error in fused depthwiseConv2d: pad must be an integer when " + + ("using dimRoundingMode " + dimRoundingMode + " but got pad " + pad + "."); }); + } + var convInfo = computeConv2DInfo(x4D.shape, $filter.shape, strides, dilations, pad, dimRoundingMode, true /* depthwise */); + var $bias; + if (bias != null) { + $bias = convertToTensor(bias, 'bias', 'fused conv2d'); + $bias = makeTypesMatch($bias, $x)[0]; + assertAndGetBroadcastShape(convInfo.outShape, $bias.shape); + } + var $preluActivationWeights; + if (preluActivationWeights != null) { + $preluActivationWeights = convertToTensor(preluActivationWeights, 'prelu weights', 'fused depthwiseConv2d'); + } + var grad = function (dy, saved) { + assert(tupleValuesAreOne(dilations), function () { return 'Error in gradient of fused depthwiseConv2d: dilation rates ' + + "greater than 1 are not yet supported. Got dilations " + + ("'" + dilations + "'"); }); + var x4D = saved[0], $filter = saved[1], y = saved[2]; + var dyActivation = getFusedDyActivation(dy, y, activation); + var biasGradient = {}; + if (bias != null) { + biasGradient = { $bias: function () { return getFusedBiasGradient($bias, dyActivation); } }; + } + return Object.assign({ + x: function () { return depthwiseConv2dDerInput(x4D.shape, dyActivation, $filter, convInfo); }, + $filter: function () { return depthwiseConv2dDerFilter(x4D, dyActivation, $filter.shape, convInfo); }, + }, biasGradient); + }; + var inputs = { x: x4D, $filter: $filter }; + if (bias != null) { + inputs.$bias = $bias; + } + if (preluActivationWeights != null) { + inputs.$preluActivationWeights = $preluActivationWeights; + } + var res = ENGINE.runKernelFunc(function (backend, save) { + var res = backend.fusedDepthwiseConv2D({ + input: x4D, + filter: $filter, + convInfo: convInfo, + bias: $bias, + activation: activation, + preluActivationWeights: $preluActivationWeights + }); + save([x4D, $filter, res]); + return res; + }, inputs, grad); + if (reshapedTo4D) { + return res.as3D(res.shape[1], res.shape[2], res.shape[3]); + } + return res; + } + var matMul$1 = op({ matMul_: matMul_$1 }); + var conv2d$1 = op({ conv2d_: conv2d_$1 }); + var depthwiseConv2d$1 = op({ depthwiseConv2d_: depthwiseConv2d_$1 }); + + var fused_ops = /*#__PURE__*/Object.freeze({ + matMul: matMul$1, + conv2d: conv2d$1, + depthwiseConv2d: depthwiseConv2d$1 + }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + + var ops = /*#__PURE__*/Object.freeze({ + image: image_ops, + linalg: linalg_ops, + losses: loss_ops, + spectral: spectral_ops, + fused: fused_ops, + signal: signal_ops, + square: square, + conv1d: conv1d, + conv2d: conv2d, + conv3d: conv3d, + depthwiseConv2d: depthwiseConv2d, + separableConv2d: separableConv2d, + conv2dTranspose: conv2dTranspose, + conv3dTranspose: conv3dTranspose, + op: op, + batchNormalization2d: batchNormalization2d, + batchNormalization3d: batchNormalization3d, + batchNormalization4d: batchNormalization4d, + batchNormalization: batchNormalization, + batchNorm: batchNorm, + batchNorm2d: batchNorm2d, + batchNorm3d: batchNorm3d, + batchNorm4d: batchNorm4d, + booleanMaskAsync: booleanMaskAsync, + complex: complex, + real: real, + imag: imag, + concat: concat, + concat1d: concat1d, + concat2d: concat2d, + concat3d: concat3d, + concat4d: concat4d, + split: split, + matMul: matMul, + dot: dot, + outerProduct: outerProduct, + reverse: reverse, + reverse1d: reverse1d, + reverse2d: reverse2d, + reverse3d: reverse3d, + reverse4d: reverse4d, + maxPool: maxPool, + avgPool: avgPool, + pool: pool, + maxPool3d: maxPool3d, + avgPool3d: avgPool3d, + slice: slice, + slice1d: slice1d, + slice2d: slice2d, + slice3d: slice3d, + slice4d: slice4d, + abs: abs, + acos: acos, + acosh: acosh, + asin: asin, + asinh: asinh, + atan: atan, + atanh: atanh, + ceil: ceil, + clipByValue: clipByValue, + cos: cos, + cosh: cosh, + erf: erf, + exp: exp, + expm1: expm1, + floor: floor, + log: log, + log1p: log1p, + logSigmoid: logSigmoid, + neg: neg, + reciprocal: reciprocal, + round: round, + rsqrt: rsqrt, + sigmoid: sigmoid, + sign: sign, + isNaN: isNaN$1, + isInf: isInf, + isFinite: isFinite$1, + sin: sin, + sinh: sinh, + softplus: softplus, + sqrt: sqrt, + step: step, + tan: tan, + tanh: tanh$1, + all: all, + any: any, + argMax: argMax, + argMin: argMin, + logSumExp: logSumExp, + max: max, + mean: mean, + min: min, + moments: moments, + sum: sum$1, + prod: prod, + equal: equal, + equalStrict: equalStrict, + greater: greater, + greaterEqual: greaterEqual, + greaterEqualStrict: greaterEqualStrict, + greaterStrict: greaterStrict, + less: less, + lessEqual: lessEqual, + lessEqualStrict: lessEqualStrict, + lessStrict: lessStrict, + notEqual: notEqual, + notEqualStrict: notEqualStrict, + add: add, + addN: addN, + addStrict: addStrict, + atan2: atan2, + div: div, + divNoNan: divNoNan, + divStrict: divStrict, + floorDiv: floorDiv, + maximum: maximum, + maximumStrict: maximumStrict, + minimum: minimum, + minimumStrict: minimumStrict, + mod: mod, + modStrict: modStrict, + mul: mul, + mulStrict: mulStrict, + pow: pow, + powStrict: powStrict, + squaredDifference: squaredDifference, + squaredDifferenceStrict: squaredDifferenceStrict, + sub: sub, + subStrict: subStrict, + elu: elu, + leakyRelu: leakyRelu, + prelu: prelu, + relu: relu, + relu6: relu6, + selu: selu, + logicalAnd: logicalAnd, + logicalNot: logicalNot, + logicalOr: logicalOr, + logicalXor: logicalXor, + where: where, + whereAsync: whereAsync, + buffer: buffer, + print: print, + batchToSpaceND: batchToSpaceND, + cast: cast, + clone: clone, + cumsum: cumsum, + depthToSpace: depthToSpace, + expandDims: expandDims, + eye: eye, + multinomial: multinomial, + oneHot: oneHot, + pad: pad, + pad1d: pad1d, + pad2d: pad2d, + pad3d: pad3d, + pad4d: pad4d, + rand: rand, + randomNormal: randomNormal, + randomGamma: randomGamma, + randomUniform: randomUniform, + reshape: reshape, + spaceToBatchND: spaceToBatchND, + squeeze: squeeze, + stack: stack, + tile: tile, + truncatedNormal: truncatedNormal, + unstack: unstack, + setdiff1dAsync: setdiff1dAsync, + fill: fill, + linspace: linspace, + ones: ones$1, + range: range, + scalar: scalar, + tensor: tensor, + tensor1d: tensor1d, + tensor2d: tensor2d, + tensor3d: tensor3d, + tensor4d: tensor4d, + tensor5d: tensor5d, + tensor6d: tensor6d, + variable: variable, + zeros: zeros, + onesLike: onesLike, + zerosLike: zerosLike, + transpose: transpose, + softmax: softmax, + logSoftmax: logSoftmax, + localResponseNormalization: localResponseNormalization, + norm: norm, + gather: gather, + unsortedSegmentSum: unsortedSegmentSum, + basicLSTMCell: basicLSTMCell, + multiRNNCell: multiRNNCell, + movingAverage: movingAverage, + stridedSlice: stridedSlice, + topk: topk, + scatterND: scatterND, + fft: fft, + ifft: ifft, + rfft: rfft, + irfft: irfft, + sparseToDense: sparseToDense, + gatherND: gatherND, + diag: diag, + dropout: dropout, + hannWindow: hannWindow, + hammingWindow: hammingWindow, + frame: frame, + stft: stft, + inTopKAsync: inTopKAsync + }); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function assertNotComplex(tensor, opName) { + if (!Array.isArray(tensor)) { + tensor = [tensor]; + } + tensor.forEach(function (t) { + if (t != null) { + assert(t.dtype !== 'complex64', function () { return opName + " does not support complex64 tensors."; }); + } + }); + } + + /** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + function mapActivation(backend, x, activation, preluActivationWeights) { + if (activation === 'linear') { + return backend.linear(x); + } + else if (activation === 'relu') { + return backend.relu(x); + } + else if (activation === 'elu') { + return backend.elu(x); + } + else if (activation === 'relu6') { + return backend.relu6(x); + } + else if (activation === 'prelu') { + return backend.prelu(x, preluActivationWeights); + } + throw new Error("Activation " + activation + " has not been implemented for the CPU backend."); + } + function createCanvas$1() { + if (typeof OffscreenCanvas !== 'undefined') { + return new OffscreenCanvas(300, 150); + } + else if (typeof document !== 'undefined') { + return document.createElement('canvas'); + } + return null; + } + var MathBackendCPU = /** @class */ (function (_super) { + __extends(MathBackendCPU, _super); + function MathBackendCPU() { + var _this = _super.call(this) || this; + _this.blockSize = 48; + _this.firstUse = true; + if (env().get('IS_BROWSER')) { + var canvas = createCanvas$1(); + if (canvas !== null) { + _this.fromPixels2DContext = + canvas.getContext('2d'); + } + } + _this.data = new DataStorage(_this, ENGINE); + return _this; + } + MathBackendCPU.prototype.write = function (values, shape, dtype) { + if (this.firstUse) { + this.firstUse = false; + if (env().get('IS_NODE')) { + warn('\n============================\n' + + 'Hi there 👋. Looks like you are running TensorFlow.js in ' + + 'Node.js. To speed things up dramatically, install our node ' + + 'backend, which binds to TensorFlow C++, by running ' + + 'npm i @tensorflow/tfjs-node, ' + + 'or npm i @tensorflow/tfjs-node-gpu if you have CUDA. ' + + 'Then call require(\'@tensorflow/tfjs-node\'); (-gpu ' + + 'suffix for CUDA) at the start of your program. ' + + 'Visit https://github.com/tensorflow/tfjs-node for more details.' + + '\n============================\n'); + } + } + var dataId = {}; + this.data.set(dataId, { values: values, dtype: dtype }); + return dataId; + }; + MathBackendCPU.prototype.move = function (dataId, values, shape, dtype) { + this.data.set(dataId, { values: values, dtype: dtype }); + }; + MathBackendCPU.prototype.numDataIds = function () { + return this.data.numDataIds(); + }; + MathBackendCPU.prototype.fromPixels = function (pixels, numChannels) { + if (pixels == null) { + throw new Error('pixels passed to tf.browser.fromPixels() can not be null'); + } + var isPixelData = pixels.data instanceof Uint8Array; + var isImageData = typeof (ImageData) !== 'undefined' && pixels instanceof ImageData; + var isVideo = typeof (HTMLVideoElement) !== 'undefined' && + pixels instanceof HTMLVideoElement; + var isImage = typeof (HTMLImageElement) !== 'undefined' && + pixels instanceof HTMLImageElement; + var _a = isVideo ? + [ + pixels.videoWidth, + pixels.videoHeight + ] : + [pixels.width, pixels.height], width = _a[0], height = _a[1]; + var vals; + // tslint:disable-next-line:no-any + if (env().get('IS_NODE') && pixels.getContext == null) { + throw new Error('When running in node, pixels must be an HTMLCanvasElement ' + + 'like the one returned by the `canvas` npm package'); + } + // tslint:disable-next-line:no-any + if (pixels.getContext != null) { + // tslint:disable-next-line:no-any + vals = pixels + .getContext('2d') + .getImageData(0, 0, width, height) + .data; + } + else if (isImageData || isPixelData) { + vals = pixels.data; + } + else if (isImage || isVideo) { + if (this.fromPixels2DContext == null) { + throw new Error('Can\'t read pixels from HTMLImageElement outside ' + + 'the browser.'); + } + this.fromPixels2DContext.canvas.width = width; + this.fromPixels2DContext.canvas.height = height; + this.fromPixels2DContext.drawImage(pixels, 0, 0, width, height); + vals = this.fromPixels2DContext.getImageData(0, 0, width, height).data; + } + else { + throw new Error('pixels passed to tf.browser.fromPixels() must be either an ' + + "HTMLVideoElement, HTMLImageElement, HTMLCanvasElement, ImageData " + + "or {data: Uint32Array, width: number, height: number}, " + + ("but was " + pixels.constructor.name)); + } + var values; + if (numChannels === 4) { + values = new Int32Array(vals); + } + else { + var numPixels = width * height; + values = new Int32Array(numPixels * numChannels); + for (var i = 0; i < numPixels; i++) { + for (var channel = 0; channel < numChannels; ++channel) { + values[i * numChannels + channel] = vals[i * 4 + channel]; + } + } + } + var outShape = [height, width, numChannels]; + return tensor3d(values, outShape, 'int32'); + }; + MathBackendCPU.prototype.read = function (dataId) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.readSync(dataId)]; + }); + }); + }; + MathBackendCPU.prototype.readSync = function (dataId) { + var _a = this.data.get(dataId), dtype = _a.dtype, complexTensors = _a.complexTensors; + if (dtype === 'complex64') { + var realValues = this.readSync(complexTensors.real.dataId); + var imagValues = this.readSync(complexTensors.imag.dataId); + return mergeRealAndImagArrays(realValues, imagValues); + } + return this.data.get(dataId).values; + }; + MathBackendCPU.prototype.bufferSync = function (t) { + var data = this.readSync(t.dataId); + var decodedData = data; + if (t.dtype === 'string') { + try { + // Decode the bytes into string. + decodedData = data.map(function (d) { return decodeString(d); }); + } + catch (_a) { + throw new Error('Failed to decode encoded string bytes into utf-8'); + } + } + return buffer(t.shape, t.dtype, decodedData); + }; + MathBackendCPU.prototype.makeOutput = function (values, shape, dtype) { + var dataId = this.write(values, shape, dtype); + return ENGINE.makeTensorFromDataId(dataId, shape, dtype, this); + }; + MathBackendCPU.prototype.disposeData = function (dataId) { + if (this.data.has(dataId)) { + var complexTensors = this.data.get(dataId).complexTensors; + if (complexTensors != null) { + complexTensors.real.dispose(); + complexTensors.imag.dispose(); + } + this.data.delete(dataId); + } + }; + MathBackendCPU.prototype.time = function (f) { + return __awaiter(this, void 0, void 0, function () { + var start, kernelMs; + return __generator(this, function (_a) { + start = now(); + f(); + kernelMs = now() - start; + return [2 /*return*/, { kernelMs: kernelMs }]; + }); + }); + }; + MathBackendCPU.prototype.memory = function () { + return { + // Unreliable due to automatic gc. The numbers above are cumulative. + unreliable: true, + reasons: ['The reported memory is an upper bound. Due to automatic garbage ' + + 'collection, the true allocated memory may be less.'] + }; + }; + MathBackendCPU.prototype.complex = function (real, imag) { + var result = this.makeOutput(null, real.shape, 'complex64'); + var resultData = this.data.get(result.dataId); + // The backend owns the reference to the underlying real and imaginary + // clones. These will explicitly get disposed when the complex tensor is + // disposed. + resultData.complexTensors = { + real: ENGINE.keep(real.clone()), + imag: ENGINE.keep(imag.clone()) + }; + return result; + }; + MathBackendCPU.prototype.real = function (input) { + var resultData = this.data.get(input.dataId); + return resultData.complexTensors.real.clone(); + }; + MathBackendCPU.prototype.imag = function (input) { + var resultData = this.data.get(input.dataId); + return resultData.complexTensors.imag.clone(); + }; + MathBackendCPU.prototype.slice = function (x, begin, size) { + assertNotComplex(x, 'slice'); + var isContinous = isSliceContinous(x.shape, begin, size); + if (isContinous) { + var flatOffset = computeFlatOffset(begin, x.strides); + var length_1 = sizeFromShape(size); + var vals = this.readSync(x.dataId); + return tensor(vals.subarray(flatOffset, flatOffset + length_1), size, x.dtype); + } + var buffer$1 = buffer(size, x.dtype); + var xBuf = this.bufferSync(x); + for (var i = 0; i < buffer$1.size; ++i) { + var loc = buffer$1.indexToLoc(i); + var xLoc = loc.map(function (idx, j) { return idx + begin[j]; }); + buffer$1.values[i] = xBuf.get.apply(xBuf, xLoc); + } + return buffer$1.toTensor(); + }; + MathBackendCPU.prototype.stridedSlice = function (x, begin, end, strides) { + assertNotComplex(x, 'stridedSlice'); + var outShape = computeOutShape$2(begin, end, strides); + if (outShape.some(function (axis) { return axis === 0; })) { + return tensor([], outShape); + } + var buffer$1 = buffer(outShape, x.dtype); + var xBuf = this.bufferSync(x); + for (var i = 0; i < buffer$1.size; i++) { + var loc = buffer$1.indexToLoc(i); + var newLoc = new Array(loc.length); + for (var j = 0; j < newLoc.length; j++) { + newLoc[j] = loc[j] * strides[j] + begin[j]; + } + buffer$1.set.apply(buffer$1, [xBuf.get.apply(xBuf, newLoc)].concat(loc)); + } + return buffer$1.toTensor(); + }; + MathBackendCPU.prototype.diag = function (x) { + var xVals = this.readSync(x.dataId); + var buffer$1 = buffer([x.size, x.size], x.dtype); + var vals = buffer$1.values; + for (var i = 0; i < xVals.length; i++) { + vals[i * x.size + i] = xVals[i]; + } + return buffer$1.toTensor(); + }; + MathBackendCPU.prototype.unstack = function (x, axis) { + var num = x.shape[axis]; + var outShape = new Array(x.rank - 1); + var outIndex = 0; + for (var i = 0; i < x.rank; i++) { + if (i !== axis) { + outShape[outIndex++] = x.shape[i]; + } + } + var begin = new Array(x.rank).fill(0); + var size = x.shape.slice(); + size[axis] = 1; + var res = new Array(num); + for (var i = 0; i < res.length; i++) { + begin[axis] = i; + res[i] = this.slice(x, begin, size).reshape(outShape); + } + return res; + }; + MathBackendCPU.prototype.reverse = function (x, axis) { + assertNotComplex(x, 'reverse'); + var buffer$1 = buffer(x.shape, x.dtype); + var xBuf = this.bufferSync(x); + var _loop_1 = function (i) { + var outLoc = buffer$1.indexToLoc(i); + var inLoc = outLoc.slice(); + axis.forEach(function (ax) { return inLoc[ax] = x.shape[ax] - 1 - inLoc[ax]; }); + buffer$1.set.apply(buffer$1, [xBuf.get.apply(xBuf, inLoc)].concat(outLoc)); + }; + for (var i = 0; i < buffer$1.size; i++) { + _loop_1(i); + } + return buffer$1.toTensor(); + }; + MathBackendCPU.prototype.concat = function (tensors, axis) { + var _this = this; + if (tensors[0].dtype === 'complex64') { + var reals = tensors.map(function (t) { return real(t); }); + var imags = tensors.map(function (t) { return imag(t); }); + return complex(this.concat(reals, axis), this.concat(imags, axis)); + } + var tensors2D = tensors.map(function (t) { + var innerSize = sizeFromShape(t.shape.slice(axis)); + return t.as2D(-1, innerSize); + }); + var outShape = computeOutShape(tensors2D.map(function (t) { return t.shape; }), 1 /* axis */); + var values = buffer(outShape, tensors[0].dtype) + .values; + if (tensors2D[0].shape[0] === 1) { + // Use built-in TypedArray.set() method for speed. + var offset_1 = 0; + tensors2D.forEach(function (t) { + values.set(_this.readSync(t.dataId), offset_1); + offset_1 += t.size; + }); + } + else { + var colOffset_1 = 0; + tensors2D.forEach(function (t) { + var tVals = _this.readSync(t.dataId); + var tIdx = 0; + for (var row = 0; row < t.shape[0]; ++row) { + var resIdx = row * outShape[1] + colOffset_1; + for (var col = 0; col < t.shape[1]; ++col) { + values[resIdx + col] = tVals[tIdx++]; + } + } + colOffset_1 += t.shape[1]; + }); + } + var finalOutShape = computeOutShape(tensors.map(function (t) { return t.shape; }), axis); + return tensor(values, finalOutShape, tensors[0].dtype); + }; + MathBackendCPU.prototype.neg = function (x) { + assertNotComplex(x, 'neg'); + return this.multiply(scalar(-1), x); + }; + MathBackendCPU.prototype.add = function (a, b) { + if (a.dtype === 'complex64' || b.dtype === 'complex64') { + return this.broadcastedBinaryComplexOp(a.cast('complex64'), b.cast('complex64'), function (aReal, aImag, bReal, bImag) { + return { real: aReal + bReal, imag: aImag + bImag }; + }); + } + return this.broadcastedBinaryOp(a, b, upcastType(a.dtype, b.dtype), function (aValue, bValue) { return aValue + bValue; }); + }; + MathBackendCPU.prototype.addN = function (tensors) { + var _this = this; + assertNotComplex(tensors, 'addN'); + var vals = tensors.map(function (t) { return _this.readSync(t.dataId); }); + var result = buffer(tensors[0].shape, tensors[0].dtype); + var resultVals = result.values; + for (var i = 0; i < tensors.length; i++) { + var currVals = vals[i]; + for (var j = 0; j < resultVals.length; j++) { + resultVals[j] += currVals[j]; + } + } + return result.toTensor(); + }; + MathBackendCPU.prototype.subtract = function (a, b) { + if (a.dtype === 'complex64' || b.dtype === 'complex64') { + return this.broadcastedBinaryComplexOp(a.cast('complex64'), b.cast('complex64'), function (aReal, aImag, bReal, bImag) { + return { real: aReal - bReal, imag: aImag - bImag }; + }); + } + return this.broadcastedBinaryOp(a, b, upcastType(a.dtype, b.dtype), function (aValue, bValue) { return aValue - bValue; }); + }; + MathBackendCPU.prototype.pow = function (a, b) { + assertNotComplex([a, b], 'pow'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aValue, bValue) { return Math.pow(aValue, bValue); }); + }; + MathBackendCPU.prototype.batchMatMul = function (a, b, transposeA, transposeB) { + assertNotComplex([a, b], 'matMul'); + var sharedDim = transposeA ? a.shape[1] : a.shape[2]; + var leftDim = transposeA ? a.shape[2] : a.shape[1]; + var rightDim = transposeB ? b.shape[1] : b.shape[2]; + var batchDim = a.shape[0]; + var aValues = this.readSync(a.dataId); + var bValues = this.readSync(b.dataId); + var _a = transposeA ? + [a.strides[0], 1, a.strides[1]] : + [a.strides[0], a.strides[1], 1], aBatch = _a[0], aOuterStep = _a[1], aInnerStep = _a[2]; + var _b = transposeB ? + [1, b.strides[1], b.strides[0]] : + [b.strides[1], 1, b.strides[0]], bInnerStep = _b[0], bOuterStep = _b[1], bBatch = _b[2]; + var size = leftDim * rightDim; + var result = buffer([batchDim, leftDim, rightDim], a.dtype); + var resVals = result.values; + var blockSize = this.blockSize; + for (var b_1 = 0; b_1 < batchDim; b_1++) { + for (var i0 = 0; i0 < leftDim; i0 += blockSize) { + for (var j0 = 0; j0 < rightDim; j0 += blockSize) { + for (var k0 = 0; k0 < sharedDim; k0 += blockSize) { + // for when blockSize doesn't evenly divide the input + var iBlock = Math.min(i0 + blockSize, leftDim); + var jBlock = Math.min(j0 + blockSize, rightDim); + var kBlock = Math.min(k0 + blockSize, sharedDim); + for (var i = i0; i < iBlock; i++) { + for (var j = j0; j < jBlock; j++) { + var sum = 0.0; + for (var k = k0; k < kBlock; k++) { + sum += aValues[b_1 * aBatch + i * aOuterStep + k * aInnerStep] * + bValues[k * bInnerStep + j * bOuterStep + b_1 * bBatch]; + } + resVals[b_1 * size + (i * rightDim + j)] += sum; + } + } + } + } + } + } + return result.toTensor(); + }; + MathBackendCPU.prototype.fusedBatchMatMul = function (_a) { + var a = _a.a, b = _a.b, transposeA = _a.transposeA, transposeB = _a.transposeB, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var result = this.batchMatMul(a, b, transposeA, transposeB); + if (bias) { + result = this.add(result, bias); + } + if (activation) { + result = + mapActivation(this, result, activation, preluActivationWeights); + } + return result; + }; + MathBackendCPU.prototype.multiply = function (a, b) { + if (a.dtype === 'complex64' || b.dtype === 'complex64') { + return this.broadcastedBinaryComplexOp(a.cast('complex64'), b.cast('complex64'), function (aReal, aImag, bReal, bImag) { + return { + real: aReal * bReal - aImag * bImag, + imag: aReal * bImag + aImag * bReal + }; + }); + } + return this.broadcastedBinaryOp(a, b, upcastType(a.dtype, b.dtype), function (aValue, bValue) { return aValue * bValue; }); + }; + MathBackendCPU.prototype.realDivide = function (a, b) { + assertNotComplex([a, b], 'realDivide'); + var op = function (a, b) { return a / b; }; + var outputDtype = 'float32'; + return this.broadcastedBinaryOp(a, b, outputDtype, op); + }; + MathBackendCPU.prototype.floorDiv = function (a, b) { + assertNotComplex([a, b], 'floorDiv'); + var op = function (a, b) { return Math.floor(a / b); }; + var outputDtype = 'int32'; + return this.broadcastedBinaryOp(a, b, outputDtype, op); + }; + MathBackendCPU.prototype.sum = function (x, axes) { + assertNotComplex(x, 'sum'); + assertAxesAreInnerMostDims('sum', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var resultDtype = upcastType(x.dtype, 'int32'); + var result = zeros(outShape, resultDtype); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var sum = 0; + for (var j = 0; j < reduceSize; ++j) { + sum += aVals[offset + j]; + } + vals[i] = sum; + } + return result; + }; + MathBackendCPU.prototype.prod = function (x, axes) { + assertNotComplex(x, 'sum'); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var resultDtype = upcastType(x.dtype, 'int32'); + var result = zeros(outShape, resultDtype); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var prod = 1; + for (var j = 0; j < reduceSize; ++j) { + prod *= aVals[offset + j]; + } + vals[i] = prod; + } + return result; + }; + MathBackendCPU.prototype.unsortedSegmentSum = function (x, segmentIds, numSegments) { + assertNotComplex(x, 'unsortedSegmentSum'); + var res = []; + // Reshape the segment id's so that they can be broadcast with + // x. The new shape should be [segmentIds.shape, 1, ..., 1] + var numIters = x.rank - segmentIds.rank; + for (var i = 0; i < numIters; ++i) { + segmentIds = segmentIds.expandDims(i + 1); + } + for (var i = 0; i < numSegments; ++i) { + var segmentId = scalar(i, 'int32'); + var mask = equal(segmentId, segmentIds).asType('float32'); + var sum = mask.mul(x).sum(0); + res.push(sum); + } + return stack(res); + }; + MathBackendCPU.prototype.argMin = function (x, axis) { + assertNotComplex(x, 'argMin'); + var axes = [axis]; + assertAxesAreInnerMostDims('argMin', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = zeros(outShape, 'int32'); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var min = aVals[offset]; + var minIndex = 0; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value < min) { + min = value; + minIndex = j; + } + } + vals[i] = minIndex; + } + return result; + }; + MathBackendCPU.prototype.argMax = function (x, axis) { + assertNotComplex(x, 'argMax'); + var axes = [axis]; + assertAxesAreInnerMostDims('argMax', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = zeros(outShape, 'int32'); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var max = aVals[offset]; + var maxIndex = 0; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value > max) { + max = value; + maxIndex = j; + } + } + vals[i] = maxIndex; + } + return result; + }; + MathBackendCPU.prototype.cumsum = function (x, axis, exclusive, reverse) { + assertNotComplex(x, 'cumsum'); + if (axis !== x.rank - 1) { + throw new Error("backend.cumsum in CPU expects an inner-most axis=" + (x.rank - 1) + " " + + ("but got axis=" + axis)); + } + var resultDtype = upcastType(x.dtype, 'int32'); + var result = zeros(x.shape, resultDtype); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + var finalDim = x.shape[x.rank - 1]; + var indexAdjuster = reverse ? + function (i, j) { return i + finalDim - j - 1; } : + function (i, j) { return i + j; }; + for (var i = 0; i < aVals.length; i += finalDim) { + for (var j = 0; j < finalDim; j++) { + var idx = indexAdjuster(i, j); + if (j === 0) { + vals[idx] = exclusive ? 0 : aVals[idx]; + } + else { + var prevIdx = indexAdjuster(i, j - 1); + vals[idx] = exclusive ? aVals[prevIdx] + vals[prevIdx] : + aVals[idx] + vals[prevIdx]; + } + } + } + return result; + }; + MathBackendCPU.prototype.equal = function (a, b) { + assertNotComplex([a, b], 'equal'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal === bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.notEqual = function (a, b) { + assertNotComplex([a, b], 'notEqual'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal !== bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.less = function (a, b) { + assertNotComplex([a, b], 'less'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal < bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.lessEqual = function (a, b) { + assertNotComplex([a, b], 'lessEqual'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal <= bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.greater = function (a, b) { + assertNotComplex([a, b], 'greater'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal > bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.greaterEqual = function (a, b) { + assertNotComplex([a, b], 'greaterEqual'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return (aVal >= bVal) ? 1 : 0; + }); + }; + MathBackendCPU.prototype.logicalNot = function (x) { + assertNotComplex(x, 'logicalNot'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = values[i] ? 0 : 1; + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.logicalAnd = function (a, b) { + assertNotComplex([a, b], 'logicalAnd'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return aVal && bVal; + }); + }; + MathBackendCPU.prototype.logicalOr = function (a, b) { + assertNotComplex([a, b], 'logicalOr'); + return this.broadcastedBinaryOp(a, b, 'bool', function (aVal, bVal) { + return aVal || bVal; + }); + }; + MathBackendCPU.prototype.select = function (condition, a, b) { + assertNotComplex([condition, a, b], 'select'); + var values = this.readSync(condition.dataId); + var aValues = this.readSync(a.dataId); + var bValues = this.readSync(b.dataId); + var result = zeros(a.shape, upcastType(a.dtype, b.dtype)); + var newValues = this.readSync(result.dataId); + var index = 0; + var offset = condition.rank === 0 || condition.rank > 1 || a.rank === 1 ? + 1 : + sizeFromShape(a.shape.slice(1)); + for (var i = 0; i < values.length; i++) { + for (var j = 0; j < offset; j++) { + if (values[i] === 1) { + newValues[index++] = aValues[i]; + } + else { + newValues[index++] = bValues[i]; + } + } + } + return result; + }; + MathBackendCPU.prototype.where = function (condition) { + assertNotComplex([condition], 'where'); + var condVals = this.readSync(condition.dataId); + return whereImpl(condition.shape, condVals); + }; + MathBackendCPU.prototype.topk = function (x, k, sorted) { + assertNotComplex(x, 'topk'); + var xVals = this.readSync(x.dataId); + return topkImpl(xVals, x.shape, x.dtype, k, sorted); + }; + MathBackendCPU.prototype.min = function (x, axes) { + assertNotComplex(x, 'min'); + assertAxesAreInnerMostDims('min', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = zeros(outShape, x.dtype); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var min = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value < min) { + min = value; + } + } + vals[i] = min; + } + return result; + }; + MathBackendCPU.prototype.minimum = function (a, b) { + assertNotComplex([a, b], 'minimum'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { return Math.min(aVal, bVal); }); + }; + MathBackendCPU.prototype.mod = function (a, b) { + assertNotComplex([a, b], 'mod'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { + var rem = aVal % bVal; + if ((aVal < 0 && bVal < 0) || (aVal >= 0 && bVal >= 0)) { + return rem; + } + else { + return (rem + bVal) % bVal; + } + }); + }; + MathBackendCPU.prototype.max = function (x, axes) { + assertNotComplex(x, 'max'); + assertAxesAreInnerMostDims('max', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = zeros(outShape, x.dtype); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var max = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + if (value > max) { + max = value; + } + } + vals[i] = max; + } + return result; + }; + MathBackendCPU.prototype.maximum = function (a, b) { + assertNotComplex([a, b], 'maximum'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { return Math.max(aVal, bVal); }); + }; + MathBackendCPU.prototype.all = function (x, axes) { + assertNotComplex(x, 'all'); + assertAxesAreInnerMostDims('all', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = zeros(outShape, x.dtype); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var all = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + all = all && value; + } + vals[i] = all; + } + return result; + }; + MathBackendCPU.prototype.any = function (x, axes) { + assertNotComplex(x, 'any'); + assertAxesAreInnerMostDims('any', axes, x.rank); + var _a = computeOutAndReduceShapes(x.shape, axes), outShape = _a[0], reduceShape = _a[1]; + var result = zeros(outShape, x.dtype); + var reduceSize = sizeFromShape(reduceShape); + var vals = this.readSync(result.dataId); + var aVals = this.readSync(x.dataId); + for (var i = 0; i < vals.length; ++i) { + var offset = i * reduceSize; + var anyVal = aVals[offset]; + for (var j = 0; j < reduceSize; ++j) { + var value = aVals[offset + j]; + anyVal = anyVal || value; + } + vals[i] = anyVal; + } + return result; + }; + MathBackendCPU.prototype.squaredDifference = function (a, b) { + assertNotComplex([a, b], 'squaredDifference'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aVal, bVal) { + var diff = aVal - bVal; + return diff * diff; + }); + }; + MathBackendCPU.prototype.ceil = function (x) { + assertNotComplex(x, 'ceil'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.ceil(values[i]); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.floor = function (x) { + assertNotComplex(x, 'floor'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.floor(values[i]); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.sign = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (values[i] < 0) { + newValues[i] = -1; + } + else if (values[i] > 0) { + newValues[i] = 1; + } + else { + newValues[i] = 0; + } + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.isNaN = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (Number.isNaN(values[i])) { + newValues[i] = 1; + } + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.isInf = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (Math.abs(values[i]) === Infinity) { + newValues[i] = 1; + } + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.isFinite = function (x) { + assertNotComplex(x, 'x'); + var values = this.readSync(x.dataId); + var newValues = new Uint8Array(values.length); + for (var i = 0; i < values.length; ++i) { + if (Number.isFinite(values[i])) { + newValues[i] = 1; + } + } + return this.makeOutput(newValues, x.shape, 'bool'); + }; + MathBackendCPU.prototype.round = function (x) { + assertNotComplex(x, 'round'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + // The algorithm is based on banker's rounding. + var base = Math.floor(values[i]); + if (values[i] - base < 0.5) { + newValues[i] = Math.floor(values[i]); + } + else if (values[i] - base > 0.5) { + newValues[i] = Math.ceil(values[i]); + } + else { + if (base % 2.0 === 0.0) { + newValues[i] = base; + } + else { + newValues[i] = base + 1.0; + } + } + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.exp = function (x) { + assertNotComplex(x, 'exp'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.exp(values[i]); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.expm1 = function (x) { + assertNotComplex(x, 'expm1'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = Math.expm1(values[i]); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.log = function (x) { + assertNotComplex(x, 'log'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = Math.log(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.log1p = function (x) { + assertNotComplex(x, 'log1p'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = Math.log1p(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.sqrt = function (x) { + assertNotComplex(x, 'sqrt'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = Math.sqrt(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.rsqrt = function (x) { + assertNotComplex(x, 'rsqrt'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = 1 / Math.sqrt(value); + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.reciprocal = function (x) { + assertNotComplex(x, 'reciprocal'); + var values = this.readSync(x.dataId); + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + newValues[i] = 1 / values[i]; + } + return this.makeOutput(newValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.linear = function (x) { + return x; + }; + MathBackendCPU.prototype.relu = function (x) { + assertNotComplex(x, 'relu'); + var res = zeros(x.shape, x.dtype); + var resVals = this.readSync(res.dataId); + var inVals = this.readSync(x.dataId); + for (var i = 0; i < inVals.length; ++i) { + resVals[i] = Math.max(0, inVals[i]); + } + return res; + }; + MathBackendCPU.prototype.relu6 = function (x) { + assertNotComplex(x, 'relu'); + var res = zeros(x.shape, x.dtype); + var resVals = this.readSync(res.dataId); + var inVals = this.readSync(x.dataId); + for (var i = 0; i < inVals.length; ++i) { + resVals[i] = Math.min(Math.max(0, inVals[i]), 6); + } + return res; + }; + MathBackendCPU.prototype.prelu = function (x, a) { + assertNotComplex([x, a], 'prelu'); + return this.broadcastedBinaryOp(x, a, x.dtype, function (xValue, aValue) { return xValue < 0 ? aValue * xValue : xValue; }); + }; + MathBackendCPU.prototype.elu = function (x) { + assertNotComplex(x, 'elu'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + if (v >= 0) { + resultValues[i] = v; + } + else { + resultValues[i] = (Math.exp(v) - 1); + } + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.eluDer = function (dy, y) { + assertNotComplex([dy, y], 'eluDer'); + var resultValues = new Float32Array(y.size); + var values = this.readSync(y.dataId); + var dyValues = this.readSync(dy.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + if (v >= 1) { + resultValues[i] = dyValues[i]; + } + else { + resultValues[i] = dyValues[i] * (v + 1); + } + } + return this.makeOutput(resultValues, y.shape, 'float32'); + }; + MathBackendCPU.prototype.selu = function (x) { + assertNotComplex(x, 'selu'); + // Stable and Attracting Fixed Point (0, 1) for Normalized Weights. + // see: https://arxiv.org/abs/1706.02515 + var scaleAlpha = SELU_SCALEALPHA; + var scale = SELU_SCALE; + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + if (v >= 0) { + resultValues[i] = scale * v; + } + else { + resultValues[i] = scaleAlpha * (Math.exp(v) - 1); + } + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.clip = function (x, min, max) { + assertNotComplex(x, 'clip'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var v = values[i]; + resultValues[i] = v > max ? max : (v < min ? min : v); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.abs = function (x) { + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.abs(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.complexAbs = function (x) { + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < x.size; ++i) { + var real_1 = values[i * 2]; + var imag_1 = values[i * 2 + 1]; + resultValues[i] = Math.hypot(real_1, imag_1); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.int = function (x) { + assertNotComplex(x, 'int'); + var resultValues = new Int32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = values[i]; + } + return this.makeOutput(resultValues, x.shape, 'int32'); + }; + MathBackendCPU.prototype.sigmoid = function (x) { + assertNotComplex(x, 'sigmoid'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = 1 / (1 + Math.exp(-values[i])); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.softplus = function (x) { + assertNotComplex(x, 'softplus'); + // mirrors the implementation of tf.nn.softplus: https://goo.gl/vkcvwX + // epsilon is the difference between 1.0 and the next representable float. + // For a single precision 32 bit float this should be 2^-23, see: + // https://math.byu.edu/~schow/work/IEEEFloatingPoint.htm + var epsilon = 1.1920928955078125e-7; + var threshold = Math.log(epsilon) + 2.0; + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + // Value above which exp(x) may overflow, but softplus(x) == x + // is within machine epsilon. + var tooLarge = values[i] > -threshold; + // Value below which exp(x) may underflow, but softplus(x) == exp(x) + // is within machine epsilon. + var tooSmall = values[i] < threshold; + var expX = Math.exp(values[i]); + var result = void 0; + if (tooSmall) { + result = expX; + } + else if (tooLarge) { + result = values[i]; + } + else { + result = Math.log(1.0 + expX); + } + resultValues[i] = result; + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.sin = function (x) { + assertNotComplex(x, 'sin'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.sin(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.cos = function (x) { + assertNotComplex(x, 'cos'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.cos(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.tan = function (x) { + assertNotComplex(x, 'tan'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.tan(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.asin = function (x) { + assertNotComplex(x, 'asin'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.asin(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.acos = function (x) { + assertNotComplex(x, 'acos'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.acos(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.atan = function (x) { + assertNotComplex(x, 'atan'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.atan(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.atan2 = function (a, b) { + assertNotComplex([a, b], 'atan2'); + return this.broadcastedBinaryOp(a, b, a.dtype, function (aValue, bValue) { return Math.atan2(aValue, bValue); }); + }; + MathBackendCPU.prototype.sinh = function (x) { + assertNotComplex(x, 'sinh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.sinh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.cosh = function (x) { + assertNotComplex(x, 'cosh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.cosh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.tanh = function (x) { + assertNotComplex(x, 'tanh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = tanh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.asinh = function (x) { + assertNotComplex(x, 'asinh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.asinh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.acosh = function (x) { + assertNotComplex(x, 'acosh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.acosh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.atanh = function (x) { + assertNotComplex(x, 'atanh'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + resultValues[i] = Math.atanh(values[i]); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.erf = function (x) { + assertNotComplex(x, 'erf'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + var p = ERF_P; + var a1 = ERF_A1; + var a2 = ERF_A2; + var a3 = ERF_A3; + var a4 = ERF_A4; + var a5 = ERF_A5; + for (var i = 0; i < values.length; ++i) { + var sign = Math.sign(values[i]); + var v = Math.abs(values[i]); + var t = 1.0 / (1.0 + p * v); + resultValues[i] = sign * + (1.0 - + (((((a5 * t + a4) * t) + a3) * t + a2) * t + a1) * t * + Math.exp(-v * v)); + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.step = function (x, alpha) { + if (alpha === void 0) { alpha = 0; } + assertNotComplex(x, 'step'); + var resultValues = new Float32Array(x.size); + var values = this.readSync(x.dataId); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + if (isNaN(value)) { + resultValues[i] = NaN; + } + else { + resultValues[i] = value > 0 ? 1 : alpha; + } + } + return this.makeOutput(resultValues, x.shape, 'float32'); + }; + MathBackendCPU.prototype.fusedConv2d = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var result = this.conv2d(input, filter, convInfo); + if (bias) { + result = this.add(result, bias); + } + if (activation) { + result = + mapActivation(this, result, activation, preluActivationWeights); + } + return result; + }; + MathBackendCPU.prototype.conv2d = function (x, filter, convInfo) { + assertNotComplex([x, filter], 'conv2d'); + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var padLeft = convInfo.padInfo.left; + var padTop = convInfo.padInfo.top; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + var y = buffer(convInfo.outShape, x.dtype); + var xBatchStride = x.strides[0]; + var xRowStride = isChannelsLast ? x.strides[1] : x.strides[2]; + var xColStride = isChannelsLast ? x.strides[2] : 1; + var xChannelStride = isChannelsLast ? 1 : x.strides[1]; + var yBatchStride = y.strides[0]; + var yRowStride = isChannelsLast ? y.strides[1] : y.strides[2]; + var yColStride = isChannelsLast ? y.strides[2] : 1; + var yChannelStride = isChannelsLast ? 1 : y.strides[1]; + var xVals = this.readSync(x.dataId); + var wVals = this.readSync(filter.dataId); + var yVals = y.values; + for (var b = 0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * xBatchStride; + var yOffset1 = b * yBatchStride; + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var yOffset2 = yOffset1 + yR * yRowStride; + var xRCorner = yR * convInfo.strideHeight - padTop; + for (var wR = 0; wR < filterHeight; wR++) { + var xR = xRCorner + wR * dilationHeight; + if (xR < 0 || xR >= convInfo.inHeight) { + continue; + } + var wOffset1 = wR * filter.strides[0]; + var xOffset2 = xOffset1 + xR * xRowStride; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var yOffset3 = yOffset2 + yC * yColStride; + var xCCorner = yC * convInfo.strideWidth - padLeft; + for (var wC = 0; wC < filterWidth; wC++) { + var xC = xCCorner + wC * dilationWidth; + if (xC < 0 || xC >= convInfo.inWidth) { + continue; + } + var wOffset2 = wOffset1 + wC * filter.strides[1]; + var xOffset3 = xOffset2 + xC * xColStride; + var wOffset3 = wOffset2; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var xVal = xVals[xOffset3 + d1 * xChannelStride]; + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + yVals[yOffset3 + d2 * yChannelStride] += + xVal * wVals[wOffset3 + d2]; + } + wOffset3 += convInfo.outChannels; + } + } + } + } + } + } + return y.toTensor(); + }; + MathBackendCPU.prototype.conv3d = function (x, filter, convInfo) { + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var padFront = convInfo.padInfo.front; + var padLeft = convInfo.padInfo.left; + var padTop = convInfo.padInfo.top; + var y = buffer(convInfo.outShape, x.dtype); + var xVals = this.readSync(x.dataId); + var wVals = this.readSync(filter.dataId); + var yVals = y.values; + for (var b = 0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * x.strides[0]; + var yOffset1 = b * y.strides[0]; + for (var yF = 0; yF < convInfo.outDepth; ++yF) { + var yOffset2 = yOffset1 + yF * y.strides[1]; + var xFCorner = yF * convInfo.strideDepth - padFront; + for (var wF = 0; wF < filterDepth; wF++) { + var xF = xFCorner + wF * dilationDepth; + if (xF < 0 || xF >= convInfo.inDepth) { + continue; + } + var wOffset1 = wF * filter.strides[0]; + var xOffset2 = xOffset1 + xF * x.strides[1]; + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var yOffset3 = yOffset2 + yR * y.strides[2]; + var xRCorner = yR * convInfo.strideHeight - padTop; + for (var wR = 0; wR < filterHeight; wR++) { + var xR = xRCorner + wR * dilationHeight; + if (xR < 0 || xR >= convInfo.inHeight) { + continue; + } + var wOffset2 = wOffset1 + wR * filter.strides[1]; + var xOffset3 = xOffset2 + xR * x.strides[2]; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var yOffset4 = yOffset3 + yC * convInfo.outChannels; + var xCCorner = yC * convInfo.strideWidth - padLeft; + for (var wC = 0; wC < filterWidth; wC++) { + var xC = xCCorner + wC * dilationWidth; + if (xC < 0 || xC >= convInfo.inWidth) { + continue; + } + var wOffset3 = wOffset2 + wC * filter.strides[2]; + var xOffset4 = xOffset3 + xC * convInfo.inChannels; + var wOffset4 = wOffset3; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var xVal = xVals[xOffset4 + d1]; + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + yVals[yOffset4 + d2] += xVal * wVals[wOffset4 + d2]; + } + wOffset4 += convInfo.outChannels; + } + } + } + } + } + } + } + } + return y.toTensor(); + }; + MathBackendCPU.prototype.conv2dDerInput = function (dy, filter, convInfo) { + assertNotComplex([dy, filter], 'conv2dDerInput'); + var dx = buffer(convInfo.inShape, 'float32'); + var dxValues = dx.values; + var dyValues = this.readSync(dy.dataId); + var fltValues = this.readSync(filter.dataId); + var _a = filter.strides, fltS0 = _a[0], fltS1 = _a[1], fltS2 = _a[2]; + var batchSize = convInfo.batchSize, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth, dataFormat = convInfo.dataFormat; + var topPad = filterHeight - 1 - convInfo.padInfo.top; + var leftPad = filterWidth - 1 - convInfo.padInfo.left; + var isChannelsLast = dataFormat === 'channelsLast'; + var xBatchStride = dx.strides[0]; + var xRowStride = isChannelsLast ? dx.strides[1] : dx.strides[2]; + var xColStride = isChannelsLast ? dx.strides[2] : 1; + var xChannelStride = isChannelsLast ? 1 : dx.strides[1]; + var yBatchStride = dy.strides[0]; + var yRowStride = isChannelsLast ? dy.strides[1] : dy.strides[2]; + var yColStride = isChannelsLast ? dy.strides[2] : 1; + var yChannelStride = isChannelsLast ? 1 : dy.strides[1]; + for (var b = 0; b < batchSize; ++b) { + for (var d1 = 0; d1 < inChannels; ++d1) { + for (var xR = 0; xR < inHeight; ++xR) { + var xRCorner = xR - topPad; + var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight)); + var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight); + for (var xC = 0; xC < inWidth; ++xC) { + var xCCorner = xC - leftPad; + var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth)); + var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth); + var dotProd = 0; + for (var yR = xRMin; yR < yRMax; ++yR) { + var wR = yR * strideHeight - xRCorner; + for (var yC = xCMin; yC < yCMax; ++yC) { + var wC = yC * strideWidth - xCCorner; + var dyOffset = yBatchStride * b + yRowStride * yR + yColStride * yC; + var fltOffset = fltS0 * (filterHeight - 1 - wR) + + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1; + for (var d2 = 0; d2 < outChannels; ++d2) { + var pixel = dyValues[dyOffset + yChannelStride * d2]; + var weight = fltValues[fltOffset + d2]; + dotProd += pixel * weight; + } + } + } + var dxOffset = xBatchStride * b + xRowStride * xR + + xColStride * xC + xChannelStride * d1; + dxValues[dxOffset] = dotProd; + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.conv3dDerInput = function (dy, filter, convInfo) { + var dx = buffer(convInfo.inShape, 'float32'); + var dxValues = dx.values; + var _a = dx.strides, dxS0 = _a[0], dxS1 = _a[1], dxS2 = _a[2], dxS3 = _a[3]; + var dyValues = this.readSync(dy.dataId); + var _b = dy.strides, dyS0 = _b[0], dyS1 = _b[1], dyS2 = _b[2], dyS3 = _b[3]; + var fltValues = this.readSync(filter.dataId); + var _c = filter.strides, fltS0 = _c[0], fltS1 = _c[1], fltS2 = _c[2], fltS3 = _c[3]; + var batchSize = convInfo.batchSize, filterDepth = convInfo.filterDepth, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inDepth = convInfo.inDepth, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outDepth = convInfo.outDepth, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideDepth = convInfo.strideDepth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth; + var frontPad = filterDepth - 1 - convInfo.padInfo.front; + var topPad = filterHeight - 1 - convInfo.padInfo.top; + var leftPad = filterWidth - 1 - convInfo.padInfo.left; + for (var b = 0; b < batchSize; ++b) { + for (var d1 = 0; d1 < inChannels; ++d1) { + // Frames of depth + for (var xF = 0; xF < inDepth; ++xF) { + var xFCorner = xF - frontPad; + var xFMin = Math.max(0, Math.ceil(xFCorner / strideDepth)); + var yFMax = Math.min(outDepth, (filterDepth + xFCorner) / strideDepth); + // Rows as per standard 2d matrix notation + for (var xR = 0; xR < inHeight; ++xR) { + var xRCorner = xR - topPad; + var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight)); + var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight); + // Columns as per standard 2d matrix notation + for (var xC = 0; xC < inWidth; ++xC) { + var xCCorner = xC - leftPad; + var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth)); + var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth); + var dotProd = 0; + for (var yF = xFMin; yF < yFMax; ++yF) { + var wF = yF * strideDepth - xFCorner; + for (var yR = xRMin; yR < yRMax; ++yR) { + var wR = yR * strideHeight - xRCorner; + for (var yC = xCMin; yC < yCMax; ++yC) { + var wC = yC * strideWidth - xCCorner; + var dyOffset = dyS0 * b + dyS1 * yF + dyS2 * yR + dyS3 * yC; + var fltOffset = fltS0 * (filterDepth - 1 - wF) + + fltS1 * (filterHeight - 1 - wR) + + fltS2 * (filterWidth - 1 - wC) + fltS3 * d1; + for (var d2 = 0; d2 < outChannels; ++d2) { + var pixel = dyValues[dyOffset + d2]; + var weight = fltValues[fltOffset + d2]; + dotProd += pixel * weight; + } + } + } + } + dxValues[dxS0 * b + dxS1 * xF + dxS2 * xR + dxS3 * xC + d1] = + dotProd; + } + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.conv2dDerFilter = function (x, dy, convInfo) { + assertNotComplex([x, dy], 'conv2dDerFilter'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var isChannelsLast = convInfo.dataFormat === 'channelsLast'; + var dW = buffer(convInfo.filterShape, 'float32'); + var leftPad = convInfo.padInfo.left; + var topPad = convInfo.padInfo.top; + var xBuf = this.bufferSync(x); + var dyBuf = this.bufferSync(dy); + for (var wR = 0; wR < filterHeight; ++wR) { + var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight)); + var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight); + for (var wC = 0; wC < filterWidth; ++wC) { + var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth)); + var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth); + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + // Need to convolve. + var dotProd = 0; + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var yR = yRMin; yR < yRMax; ++yR) { + var xR = wR + yR * strideHeight - topPad; + for (var yC = yCMin; yC < yCMax; ++yC) { + var xC = wC + yC * strideWidth - leftPad; + if (isChannelsLast) { + dotProd += + xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2); + } + else { + dotProd += + xBuf.get(b, d1, xR, xC) * dyBuf.get(b, d2, yR, yC); + } + } + } + } + dW.set(dotProd, wR, wC, d1, d2); + } + } + } + } + return dW.toTensor(); + }; + MathBackendCPU.prototype.conv3dDerFilter = function (x, dy, convInfo) { + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dw = buffer(convInfo.filterShape, 'float32'); + var dwValues = dw.values; + var _a = dw.strides, dwS0 = _a[0], dwS1 = _a[1], dwS2 = _a[2], dwS3 = _a[3]; + var dyValues = this.readSync(dy.dataId); + var _b = dy.strides, dyS0 = _b[0], dyS1 = _b[1], dyS2 = _b[2], dyS3 = _b[3]; + var xValues = this.readSync(x.dataId); + var _c = x.strides, xS0 = _c[0], xS1 = _c[1], xS2 = _c[2], xS3 = _c[3]; + var frontPad = convInfo.padInfo.front; + var leftPad = convInfo.padInfo.left; + var topPad = convInfo.padInfo.top; + for (var wF = 0; wF < filterDepth; ++wF) { + var yFMin = Math.max(0, Math.ceil((frontPad - wF) / strideDepth)); + var yFMax = Math.min(convInfo.outDepth, (convInfo.inDepth + frontPad - wF) / strideDepth); + var wOffset1 = wF * dwS0; + for (var wR = 0; wR < filterHeight; ++wR) { + var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight)); + var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight); + var wOffset2 = wR * dwS1 + wOffset1; + for (var wC = 0; wC < filterWidth; ++wC) { + var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth)); + var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth); + var wOffset3 = wC * dwS2 + wOffset2; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var wOffset4 = d1 * dwS3 + wOffset3; + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + var dotProd = 0; + for (var b = 0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * xS0; + var yOffset1 = b * dyS0; + for (var yF = yFMin; yF < yFMax; ++yF) { + var xF = wF + yF * strideDepth - frontPad; + var xOffset2 = xF * xS1 + xOffset1; + var yOffset2 = yF * dyS1 + yOffset1; + for (var yR = yRMin; yR < yRMax; ++yR) { + var xR = wR + yR * strideHeight - topPad; + var xOffset3 = xR * xS2 + xOffset2; + var yOffset3 = yR * dyS2 + yOffset2; + for (var yC = yCMin; yC < yCMax; ++yC) { + var xC = wC + yC * strideWidth - leftPad; + var xOffset4 = xC * xS3 + xOffset3; + var yOffset4 = yC * dyS3 + yOffset3; + dotProd += + xValues[xOffset4 + d1] * dyValues[yOffset4 + d2]; + } + } + } + } + dwValues[wOffset4 + d2] = dotProd; + } + } + } + } + } + return dw.toTensor(); + }; + MathBackendCPU.prototype.fusedDepthwiseConv2D = function (_a) { + var input = _a.input, filter = _a.filter, convInfo = _a.convInfo, bias = _a.bias, activation = _a.activation, preluActivationWeights = _a.preluActivationWeights; + var result = this.depthwiseConv2D(input, filter, convInfo); + if (bias) { + result = this.add(result, bias); + } + if (activation) { + result = + mapActivation(this, result, activation, preluActivationWeights); + } + return result; + }; + MathBackendCPU.prototype.depthwiseConv2D = function (x, filter, convInfo) { + assertNotComplex([x, filter], 'depthwiseConv2D'); + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var padLeft = convInfo.padInfo.left; + var padTop = convInfo.padInfo.top; + var chMul = convInfo.outChannels / convInfo.inChannels; + var y = buffer(convInfo.outShape, x.dtype); + var xVals = this.readSync(x.dataId); + var wVals = this.readSync(filter.dataId); + var yVals = y.values; + for (var b = 0; b < convInfo.batchSize; ++b) { + var xOffset1 = b * x.strides[0]; + var yOffset1 = b * y.strides[0]; + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var yOffset2 = yOffset1 + yR * y.strides[1]; + var xRCorner = yR * convInfo.strideHeight - padLeft; + for (var wR = 0; wR < filterHeight; ++wR) { + var xR = xRCorner + wR * dilationHeight; + if (xR < 0 || xR >= convInfo.inHeight) { + continue; + } + var wOffset1 = wR * filter.strides[0]; + var xOffset2 = xOffset1 + xR * x.strides[1]; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var yOffset3 = yOffset2 + yC * y.strides[2]; + var xCCorner = yC * convInfo.strideWidth - padTop; + for (var wC = 0; wC < filterWidth; ++wC) { + var xC = xCCorner + wC * dilationWidth; + if (xC < 0 || xC >= convInfo.inWidth) { + continue; + } + var wOffset2 = wOffset1 + wC * filter.strides[1]; + var xOffset3 = xOffset2 + xC * convInfo.inChannels; + var yOffset4 = yOffset3; + var wOffset3 = wOffset2; + for (var d1 = 0; d1 < convInfo.inChannels; ++d1) { + var xVal = xVals[xOffset3 + d1]; + for (var q = 0; q < chMul; ++q) { + yVals[yOffset4 + q] += xVal * wVals[wOffset3 + q]; + } + yOffset4 += chMul; + wOffset3 += chMul; + } + } + } + } + } + } + return y.toTensor(); + }; + MathBackendCPU.prototype.depthwiseConv2DDerInput = function (dy, filter, convInfo) { + assertNotComplex([dy, filter], 'depthwiseConv2DDerInput'); + var dx = buffer(convInfo.inShape, 'float32'); + var dxValues = dx.values; + var _a = dx.strides, dxS0 = _a[0], dxS1 = _a[1], dxS2 = _a[2]; + var dyValues = this.readSync(dy.dataId); + var _b = dy.strides, dyS0 = _b[0], dyS1 = _b[1], dyS2 = _b[2]; + var fltValues = this.readSync(filter.dataId); + var _c = filter.strides, fltS0 = _c[0], fltS1 = _c[1], fltS2 = _c[2]; + var batchSize = convInfo.batchSize, filterHeight = convInfo.filterHeight, filterWidth = convInfo.filterWidth, inChannels = convInfo.inChannels, inHeight = convInfo.inHeight, inWidth = convInfo.inWidth, outChannels = convInfo.outChannels, outHeight = convInfo.outHeight, outWidth = convInfo.outWidth, strideHeight = convInfo.strideHeight, strideWidth = convInfo.strideWidth; + var topPad = filterHeight - 1 - convInfo.padInfo.top; + var leftPad = filterWidth - 1 - convInfo.padInfo.left; + var chMul = outChannels / inChannels; + for (var b = 0; b < batchSize; ++b) { + for (var d1 = 0; d1 < inChannels; ++d1) { + for (var xR = 0; xR < inHeight; ++xR) { + var xRCorner = xR - topPad; + var xRMin = Math.max(0, Math.ceil(xRCorner / strideHeight)); + var yRMax = Math.min(outHeight, (filterHeight + xRCorner) / strideHeight); + for (var xC = 0; xC < inWidth; ++xC) { + var xCCorner = xC - leftPad; + var xCMin = Math.max(0, Math.ceil(xCCorner / strideWidth)); + var yCMax = Math.min(outWidth, (filterWidth + xCCorner) / strideWidth); + var dotProd = 0; + for (var yR = xRMin; yR < yRMax; ++yR) { + var wR = yR * strideHeight - xRCorner; + for (var yC = xCMin; yC < yCMax; ++yC) { + var wC = yC * strideWidth - xCCorner; + var dyOffset = dyS0 * b + dyS1 * yR + dyS2 * yC; + var fltOffset = fltS0 * (filterHeight - 1 - wR) + + fltS1 * (filterWidth - 1 - wC) + fltS2 * d1; + for (var dm = 0; dm < chMul; ++dm) { + var d2 = d1 * chMul + dm; + var pixel = dyValues[dyOffset + d2]; + var weight = fltValues[fltOffset + dm]; + dotProd += pixel * weight; + } + } + } + dxValues[dxS0 * b + dxS1 * xR + dxS2 * xC + d1] = dotProd; + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.depthwiseConv2DDerFilter = function (x, dy, convInfo) { + assertNotComplex([x, dy], 'depthwiseConv2DDerFilter'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dW = buffer(convInfo.filterShape, 'float32'); + var leftPad = convInfo.padInfo.left; + var topPad = convInfo.padInfo.top; + var chMul = convInfo.outChannels / convInfo.inChannels; + var xBuf = this.bufferSync(x); + var dyBuf = this.bufferSync(dy); + for (var wR = 0; wR < filterHeight; ++wR) { + var yRMin = Math.max(0, Math.ceil((topPad - wR) / strideHeight)); + var yRMax = Math.min(convInfo.outHeight, (convInfo.inHeight + topPad - wR) / strideHeight); + for (var wC = 0; wC < filterWidth; ++wC) { + var yCMin = Math.max(0, Math.ceil((leftPad - wC) / strideWidth)); + var yCMax = Math.min(convInfo.outWidth, (convInfo.inWidth + leftPad - wC) / strideWidth); + for (var d2 = 0; d2 < convInfo.outChannels; ++d2) { + var d1 = Math.trunc(d2 / chMul); + var dm = d2 % chMul; + var dotProd = 0; + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var yR = yRMin; yR < yRMax; ++yR) { + var xR = wR + yR * strideHeight - topPad; + for (var yC = yCMin; yC < yCMax; ++yC) { + var xC = wC + yC * strideWidth - leftPad; + dotProd += xBuf.get(b, xR, xC, d1) * dyBuf.get(b, yR, yC, d2); + } + } + } + dW.set(dotProd, wR, wC, d1, dm); + } + } + } + return dW.toTensor(); + }; + MathBackendCPU.prototype.tile = function (x, reps) { + assertNotComplex(x, 'tile'); + return tile$1(this.bufferSync(x), reps); + }; + MathBackendCPU.prototype.pad = function (x, paddings, constantValue) { + assertNotComplex(x, 'pad'); + var outShape = paddings.map(function (p, i) { return p[0] /* beforePad */ + x.shape[i] + p[1]; } /* afterPad */); + var start = paddings.map(function (p) { return p[0]; }); + var xBuffer = this.bufferSync(x); + var buffer$1 = buffer(outShape, x.dtype); + if (constantValue !== 0) { + buffer$1.values.fill(constantValue); + } + for (var i = 0; i < x.size; i++) { + var coords = xBuffer.indexToLoc(i); + var outCoords = coords.map(function (c, i) { return c + start[i]; }); + buffer$1.set.apply(buffer$1, [xBuffer.get.apply(xBuffer, coords)].concat(outCoords)); + } + return buffer$1.toTensor(); + }; + MathBackendCPU.prototype.transpose = function (x, perm) { + assertNotComplex(x, 'transpose'); + var newShape = new Array(x.rank); + for (var i = 0; i < newShape.length; i++) { + newShape[i] = x.shape[perm[i]]; + } + var values = this.readSync(x.dataId); + var result = buffer(newShape, x.dtype); + var xBuf = this.bufferSync(x); + for (var i = 0; i < x.size; ++i) { + var loc = xBuf.indexToLoc(i); + // Permute location. + var newLoc = new Array(loc.length); + for (var i_1 = 0; i_1 < newLoc.length; i_1++) { + newLoc[i_1] = loc[perm[i_1]]; + } + var newIndex = result.locToIndex(newLoc); + result.values[newIndex] = values[i]; + } + return result.toTensor(); + }; + MathBackendCPU.prototype.gather = function (x, indices, axis) { + assertNotComplex([x, indices], 'gather'); + var newShape = x.shape.slice(); + var indicesValues = this.readSync(indices.dataId); + newShape[axis] = indicesValues.length; + var result = buffer(newShape, x.dtype); + var xBuf = this.bufferSync(x); + for (var i = 0; i < result.size; ++i) { + var newLoc = result.indexToLoc(i); + var originalLoc = newLoc.slice(); + originalLoc[axis] = indicesValues[newLoc[axis]]; + var originalIndex = xBuf.locToIndex(originalLoc); + result.values[i] = xBuf.values[originalIndex]; + } + return result.toTensor(); + }; + MathBackendCPU.prototype.batchToSpaceND = function (x, blockShape, crops) { + assertNotComplex([x], 'batchToSpaceND'); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + var reshaped = getReshaped(x.shape, blockShape, prod); + var permuted = getPermuted(reshaped.length, blockShape.length); + var reshapedPermuted = getReshapedPermuted(x.shape, blockShape, prod); + var sliceBeginCoords = getSliceBeginCoords(crops, blockShape.length); + var sliceSize = getSliceSize(reshapedPermuted, crops, blockShape.length); + return x.reshape(reshaped) + .transpose(permuted) + .reshape(reshapedPermuted) + .slice(sliceBeginCoords, sliceSize); + }; + MathBackendCPU.prototype.spaceToBatchND = function (x, blockShape, paddings) { + assertNotComplex([x], 'spaceToBatchND'); + var prod = blockShape.reduce(function (a, b) { return a * b; }); + var completePaddings = [[0, 0]]; + completePaddings.push.apply(completePaddings, paddings); + for (var i = 1 + blockShape.length; i < x.shape.length; ++i) { + completePaddings.push([0, 0]); + } + var paddedX = x.pad(completePaddings); + var reshapedPaddedShape = getReshaped(paddedX.shape, blockShape, prod, false); + var permutedReshapedPaddedPermutation = getPermuted(reshapedPaddedShape.length, blockShape.length, false); + var flattenShape = getReshapedPermuted(paddedX.shape, blockShape, prod, false); + return paddedX.reshape(reshapedPaddedShape) + .transpose(permutedReshapedPaddedPermutation) + .reshape(flattenShape); + }; + MathBackendCPU.prototype.pool = function (x, convInfo, poolType) { + assertNotComplex(x, 'pool'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY : + Number.POSITIVE_INFINITY); + var xValues = this.readSync(x.dataId); + var output = buffer(convInfo.outShape, x.dtype); + var outputVals = output.values; + var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * convInfo.outShape[3]; + var outputRowStrides = convInfo.outShape[2] * convInfo.outShape[3]; + var outputColStrides = convInfo.outShape[3]; + for (var b = 0; b < convInfo.batchSize; ++b) { + var outputBatchOffset = b * outputBatchStrides; + var inputBatchOffset = b * x.strides[0]; + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var xRCorner = yR * strideHeight - padTop; + var xRMin = Math.max(0, xRCorner); + var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner); + var outputRowOffset = outputBatchOffset + yR * outputRowStrides; + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var xCCorner = yC * strideWidth - padLeft; + var xCMin = Math.max(0, xCCorner); + var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner); + var minMaxValue = initialValue; + var avgValue = 0; + var count = 0; + for (var xR = xRMin; xR < xRMax; xR += dilationHeight) { + var xROffset = inputBatchOffset + xR * x.strides[1]; + for (var xC = xCMin; xC < xCMax; xC += dilationWidth) { + var xCOffset = xROffset + xC * x.strides[2]; + var pixel = xValues[xCOffset + d]; + if ((poolType === 'max' && pixel > minMaxValue)) { + minMaxValue = pixel; + } + else if (poolType === 'avg') { + avgValue += pixel; + count++; + } + } + if (isNaN(minMaxValue)) { + break; + } + } + var outputOffset = outputRowOffset + yC * outputColStrides + d; + outputVals[outputOffset] = + poolType === 'avg' ? avgValue / count : minMaxValue; + } + } + } + } + return output.toTensor(); + }; + MathBackendCPU.prototype.maxPool = function (x, convInfo) { + return this.pool(x, convInfo, 'max'); + }; + MathBackendCPU.prototype.maxPoolPositions = function (x, convInfo) { + var maxPositions = buffer(convInfo.outShape, 'int32'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var xBuf = this.bufferSync(x); + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var yR = 0; yR < convInfo.outHeight; ++yR) { + var xRCorner = yR * strideHeight - padTop; + var xRMin = xRCorner; + while (xRMin < 0) { + xRMin += dilationHeight; + } + // const xRMin = Math.max(0, xRCorner); + var xRMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRCorner); + for (var yC = 0; yC < convInfo.outWidth; ++yC) { + var xCCorner = yC * strideWidth - padLeft; + var xCMin = xCCorner; + while (xCMin < 0) { + xCMin += dilationWidth; + } + var xCMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xCCorner); + var maxValue = Number.NEGATIVE_INFINITY; + var maxPosition = -1; + for (var xR = xRMin; xR < xRMax; xR += dilationHeight) { + var wR = xR - xRCorner; + for (var xC = xCMin; xC < xCMax; xC += dilationWidth) { + var wC = xC - xCCorner; + var pixel = xBuf.get(b, xR, xC, d); + if (pixel > maxValue) { + maxValue = pixel; + maxPosition = wR * effectiveFilterWidth + wC; + } + } + } + maxPositions.set(maxPosition, b, yR, yC, d); + } + } + } + } + return maxPositions.toTensor(); + }; + MathBackendCPU.prototype.maxPoolBackprop = function (dy, x, y, convInfo) { + assertNotComplex([x, y], 'maxPoolBackprop'); + var maxPositions = this.maxPoolPositions(x, convInfo); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = buffer(x.shape, 'float32'); + var maxPosBuf = this.bufferSync(maxPositions); + var dyBuf = this.bufferSync(dy); + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) { + for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) { + // Shader code begins. + var dyRCorner = dxR - padTop; + var dyCCorner = dxC - padLeft; + var dotProd = 0; + for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) { + var dyR = (dyRCorner + wR) / strideHeight; + if (dyR < 0 || dyR >= convInfo.outHeight || + Math.floor(dyR) !== dyR) { + continue; + } + for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) { + var dyC = (dyCCorner + wC) / strideWidth; + if (dyC < 0 || dyC >= convInfo.outWidth || + Math.floor(dyC) !== dyC) { + continue; + } + var maxPos = effectiveFilterHeight * effectiveFilterWidth - + 1 - maxPosBuf.get(b, dyR, dyC, d); + var curPos = wR * effectiveFilterWidth + wC; + var mask = maxPos === curPos ? 1 : 0; + if (mask === 0) { + continue; + } + var pixel = dyBuf.get(b, dyR, dyC, d); + dotProd += pixel * mask; + } + } + dx.set(dotProd, b, dxR, dxC, d); + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.avgPoolBackprop = function (dy, x, convInfo) { + assertNotComplex([dy, x], 'avgPoolBackprop'); + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = buffer(x.shape, 'float32'); + var avgMultiplier = 1 / (filterHeight * filterWidth); + var dyBuf = this.bufferSync(dy); + for (var b = 0; b < convInfo.batchSize; ++b) { + for (var d = 0; d < convInfo.inChannels; ++d) { + for (var dxR = 0; dxR < convInfo.inHeight; ++dxR) { + for (var dxC = 0; dxC < convInfo.inWidth; ++dxC) { + // Shader code begins. + var dyRCorner = dxR - padTop; + var dyCCorner = dxC - padLeft; + var dotProd = 0; + for (var wR = 0; wR < effectiveFilterHeight; wR += dilationHeight) { + var dyR = (dyRCorner + wR) / strideHeight; + if (dyR < 0 || dyR >= convInfo.outHeight || + Math.floor(dyR) !== dyR) { + continue; + } + for (var wC = 0; wC < effectiveFilterWidth; wC += dilationWidth) { + var dyC = (dyCCorner + wC) / strideWidth; + if (dyC < 0 || dyC >= convInfo.outWidth || + Math.floor(dyC) !== dyC) { + continue; + } + var pixel = dyBuf.get(b, dyR, dyC, d); + dotProd += pixel; + } + } + dx.set(dotProd * avgMultiplier, b, dxR, dxC, d); + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.pool3d = function (x, convInfo, poolType) { + assertNotComplex(x, 'pool3d'); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = convInfo.padInfo.front; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var initialValue = (poolType === 'max' ? Number.NEGATIVE_INFINITY : + Number.POSITIVE_INFINITY); + var xValues = this.readSync(x.dataId); + var output = buffer(convInfo.outShape, x.dtype); + var outputVals = output.values; + var outputBatchStrides = convInfo.outShape[1] * convInfo.outShape[2] * + convInfo.outShape[3] * convInfo.outShape[4]; + var outputDepthStrides = convInfo.outShape[2] * convInfo.outShape[3] * convInfo.outShape[4]; + var outputRowStrides = convInfo.outShape[3] * convInfo.outShape[4]; + var outputColStrides = convInfo.outShape[4]; + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + var outputBatchOffset = batch * outputBatchStrides; + var inputBatchOffset = batch * x.strides[0]; + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) { + var xDepthCorner = yDepth * strideDepth - padFront; + var xDepthMin = xDepthCorner; + while (xDepthMin < 0) { + xDepthMin += dilationDepth; + } + var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner); + var outputDepthOffset = outputBatchOffset + yDepth * outputDepthStrides; + for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) { + var xRowCorner = yRow * strideHeight - padTop; + var xRowMin = xRowCorner; + while (xRowMin < 0) { + xRowMin += dilationHeight; + } + var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner); + var outputRowOffset = outputDepthOffset + yRow * outputRowStrides; + for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) { + var xColCorner = yCol * strideWidth - padLeft; + var xColMin = xColCorner; + while (xColMin < 0) { + xColMin += dilationWidth; + } + var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner); + // Shader code begins + var outputColOffset = outputRowOffset + yCol * outputColStrides; + var minMaxValue = initialValue; + var avgValue = 0; + var count = 0; + for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) { + var xDepthOffset = inputBatchOffset + xDepth * x.strides[1]; + for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) { + var xRowOffset = xDepthOffset + xRow * x.strides[2]; + for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) { + var xColOffset = xRowOffset + xCol * x.strides[3]; + var pixel = xValues[xColOffset + channel]; + if ((poolType === 'max' && pixel > minMaxValue)) { + minMaxValue = pixel; + } + else if (poolType === 'avg') { + avgValue += pixel; + count++; + } + if (isNaN(minMaxValue)) { + break; + } + } + if (isNaN(minMaxValue)) { + break; + } + } + if (isNaN(minMaxValue)) { + break; + } + } + var outputOffset = outputColOffset + channel; + outputVals[outputOffset] = + poolType === 'avg' ? avgValue / count : minMaxValue; + } + } + } + } + } + return output.toTensor(); + }; + MathBackendCPU.prototype.avgPool3d = function (x, convInfo) { + assertNotComplex(x, 'avgPool3d'); + return this.pool3d(x, convInfo, 'avg').toFloat(); + }; + MathBackendCPU.prototype.avgPool3dBackprop = function (dy, x, convInfo) { + assertNotComplex([dy, x], 'avgPool3dBackprop'); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var filterDepth = convInfo.filterDepth; + var filterHeight = convInfo.filterHeight; + var filterWidth = convInfo.filterWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = buffer(x.shape, 'float32'); + var avgMultiplier = 1 / (filterDepth * filterHeight * filterWidth); + var dyBuf = this.bufferSync(dy); + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) { + for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) { + for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) { + // Shader code begins. + var dyDepthCorner = dxDepth - padFront; + var dyRowCorner = dxRow - padTop; + var dyColCorner = dxCol - padLeft; + var dotProd = 0; + for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) { + var dyDepth = (dyDepthCorner + wDepth) / strideDepth; + if (dyDepth < 0 || dyDepth >= convInfo.outDepth || + Math.floor(dyDepth) !== dyDepth) { + continue; + } + for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) { + var dyRow = (dyRowCorner + wRow) / strideHeight; + if (dyRow < 0 || dyRow >= convInfo.outHeight || + Math.floor(dyRow) !== dyRow) { + continue; + } + for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) { + var dyCol = (dyColCorner + wCol) / strideWidth; + if (dyCol < 0 || dyCol >= convInfo.outWidth || + Math.floor(dyCol) !== dyCol) { + continue; + } + var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel); + dotProd += pixel; + } + } + } + dx.set(dotProd * avgMultiplier, batch, dxDepth, dxRow, dxCol, channel); + } + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.maxPool3d = function (x, convInfo) { + assertNotComplex(x, 'maxPool3d'); + return this.pool3d(x, convInfo, 'max').toFloat(); + }; + MathBackendCPU.prototype.maxPool3dPositions = function (x, convInfo) { + var maxPositions = buffer(convInfo.outShape, 'int32'); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = convInfo.padInfo.front; + var padTop = convInfo.padInfo.top; + var padLeft = convInfo.padInfo.left; + var xBuf = this.bufferSync(x); + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var yDepth = 0; yDepth < convInfo.outDepth; ++yDepth) { + var xDepthCorner = yDepth * strideDepth - padFront; + var xDepthMin = xDepthCorner; + while (xDepthMin < 0) { + xDepthMin += dilationDepth; + } + var xDepthMax = Math.min(convInfo.inDepth, effectiveFilterDepth + xDepthCorner); + for (var yRow = 0; yRow < convInfo.outHeight; ++yRow) { + var xRowCorner = yRow * strideHeight - padTop; + var xRowMin = xRowCorner; + while (xRowMin < 0) { + xRowMin += dilationHeight; + } + var xRowMax = Math.min(convInfo.inHeight, effectiveFilterHeight + xRowCorner); + for (var yCol = 0; yCol < convInfo.outWidth; ++yCol) { + var xColCorner = yCol * strideWidth - padLeft; + var xColMin = xColCorner; + while (xColMin < 0) { + xColMin += dilationWidth; + } + var xColMax = Math.min(convInfo.inWidth, effectiveFilterWidth + xColCorner); + // Shader code begins + var maxValue = Number.NEGATIVE_INFINITY; + var maxPosition = -1; + for (var xDepth = xDepthMin; xDepth < xDepthMax; xDepth += dilationDepth) { + var wDepth = xDepth - xDepthCorner; + for (var xRow = xRowMin; xRow < xRowMax; xRow += dilationHeight) { + var wRow = xRow - xRowCorner; + for (var xCol = xColMin; xCol < xColMax; xCol += dilationWidth) { + var wCol = xCol - xColCorner; + var pixel = xBuf.get(batch, xDepth, xRow, xCol, channel); + if (pixel >= maxValue) { + maxValue = pixel; + maxPosition = wDepth * effectiveFilterHeight * + effectiveFilterWidth + + wRow * effectiveFilterHeight + wCol; + } + } + } + } + maxPositions.set(maxPosition, batch, yDepth, yRow, yCol, channel); + } + } + } + } + } + return maxPositions.toTensor(); + }; + MathBackendCPU.prototype.maxPool3dBackprop = function (dy, x, y, convInfo) { + assertNotComplex([x, y], 'maxPool3dBackprop'); + var maxPositions = this.maxPool3dPositions(x, convInfo); + var strideDepth = convInfo.strideDepth; + var strideHeight = convInfo.strideHeight; + var strideWidth = convInfo.strideWidth; + var dilationDepth = convInfo.dilationDepth; + var dilationHeight = convInfo.dilationHeight; + var dilationWidth = convInfo.dilationWidth; + var effectiveFilterDepth = convInfo.effectiveFilterDepth; + var effectiveFilterHeight = convInfo.effectiveFilterHeight; + var effectiveFilterWidth = convInfo.effectiveFilterWidth; + var padFront = effectiveFilterDepth - 1 - convInfo.padInfo.front; + var padLeft = effectiveFilterWidth - 1 - convInfo.padInfo.left; + var padTop = effectiveFilterHeight - 1 - convInfo.padInfo.top; + var dx = buffer(x.shape, 'float32'); + var maxPosBuf = this.bufferSync(maxPositions); + var dyBuf = this.bufferSync(dy); + for (var batch = 0; batch < convInfo.batchSize; ++batch) { + for (var channel = 0; channel < convInfo.inChannels; ++channel) { + for (var dxDepth = 0; dxDepth < convInfo.inDepth; ++dxDepth) { + for (var dxRow = 0; dxRow < convInfo.inHeight; ++dxRow) { + for (var dxCol = 0; dxCol < convInfo.inWidth; ++dxCol) { + // Shader code begins + var dyDepthCorner = dxDepth - padFront; + var dyRowCorner = dxRow - padTop; + var dyColCorner = dxCol - padLeft; + var dotProd = 0; + for (var wDepth = 0; wDepth < effectiveFilterDepth; wDepth += dilationDepth) { + var dyDepth = (dyDepthCorner + wDepth) / strideDepth; + if (dyDepth < 0 || dyDepth >= convInfo.outDepth || + Math.floor(dyDepth) !== dyDepth) { + continue; + } + for (var wRow = 0; wRow < effectiveFilterHeight; wRow += dilationHeight) { + var dyRow = (dyRowCorner + wRow) / strideHeight; + if (dyRow < 0 || dyRow >= convInfo.outHeight || + Math.floor(dyRow) !== dyRow) { + continue; + } + for (var wCol = 0; wCol < effectiveFilterWidth; wCol += dilationWidth) { + var dyCol = (dyColCorner + wCol) / strideWidth; + if (dyCol < 0 || dyCol >= convInfo.outWidth || + Math.floor(dyCol) !== dyCol) { + continue; + } + var maxPos = effectiveFilterDepth * + effectiveFilterHeight * effectiveFilterWidth - + 1 - + maxPosBuf.get(batch, dyDepth, dyRow, dyCol, channel); + var curPos = wDepth * effectiveFilterHeight * effectiveFilterWidth + + wRow * effectiveFilterWidth + wCol; + var mask = maxPos === curPos ? 1 : 0; + if (mask === 0) { + continue; + } + var pixel = dyBuf.get(batch, dyDepth, dyRow, dyCol, channel); + dotProd += pixel * mask; + } + } + } + dx.set(dotProd, batch, dxDepth, dxRow, dxCol, channel); + } + } + } + } + } + return dx.toTensor(); + }; + MathBackendCPU.prototype.cast = function (x, dtype) { + return castTensor(x, dtype, this); + }; + MathBackendCPU.prototype.reshape = function (x, shape) { + return reshapeTensor(x, shape); + }; + MathBackendCPU.prototype.avgPool = function (x, convInfo) { + assertNotComplex(x, 'avgPool'); + return this.pool(x, convInfo, 'avg').toFloat(); + }; + MathBackendCPU.prototype.resizeBilinear = function (x, newHeight, newWidth, alignCorners) { + assertNotComplex(x, 'resizeBilinear'); + var _a = x.shape, batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], numChannels = _a[3]; + var xValues = this.readSync(x.dataId); + var result = new Float32Array(sizeFromShape([batch, newHeight, newWidth, numChannels])); + var effectiveInputSize = [ + (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, + (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth + ]; + var effectiveOutputSize = [ + (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, + (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth + ]; + var outputIdx = 0; + var effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0]; + var effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1]; + for (var b = 0; b < batch; b++) { + for (var r = 0; r < newHeight; r++) { + var sourceFracRow = effectiveRowSizeRatio * r; + var sourceRowFloor = Math.floor(sourceFracRow); + var rowFrac = sourceFracRow - sourceRowFloor; + var sourceRowCeil = Math.min(oldHeight - 1, Math.ceil(sourceFracRow)); + var topRowOffset = b * x.strides[0] + sourceRowFloor * x.strides[1]; + var botRowOffset = b * x.strides[0] + sourceRowCeil * x.strides[1]; + for (var c = 0; c < newWidth; c++) { + var sourceFracCol = effectiveColSizeRatio * c; + var sourceColFloor = Math.floor(sourceFracCol); + var colFrac = sourceFracCol - sourceColFloor; + var sourceColCeil = Math.min(oldWidth - 1, Math.ceil(sourceFracCol)); + var topLeftOffest = topRowOffset + sourceColFloor * x.strides[2]; + var botLeftOffset = botRowOffset + sourceColFloor * x.strides[2]; + var topRightOffset = topRowOffset + +sourceColCeil * x.strides[2]; + var botRightOffest = botRowOffset + sourceColCeil * x.strides[2]; + for (var d = 0; d < numChannels; d++) { + // Begin shader. + // Compute the fractional index of the source. + var topLeft = xValues[topLeftOffest + d]; + var bottomLeft = xValues[botLeftOffset + d]; + var topRight = xValues[topRightOffset + d]; + var bottomRight = xValues[botRightOffest + d]; + var top_1 = topLeft + (topRight - topLeft) * colFrac; + var bottom = bottomLeft + (bottomRight - bottomLeft) * colFrac; + var newValue = top_1 + (bottom - top_1) * rowFrac; + result[outputIdx++] = newValue; + } + } + } + } + return tensor(result, [batch, newHeight, newWidth, numChannels]); + }; + MathBackendCPU.prototype.resizeBilinearBackprop = function (dy, x, alignCorners) { + assertNotComplex([dy, x], 'resizeBilinearBackprop'); + var _a = x.shape, batch = _a[0], xHeight = _a[1], xWidth = _a[2], depth = _a[3]; + var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; + var output = new Float32Array(batch * xHeight * xWidth * depth); + // In the backwards pass, we want to find the pixels that were generated + // for each pixel in the input image the forward pass and add the + // corresponding coefficient from dy to the gradient (with some + // interpolation). + var effectiveXSize = [ + (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, + (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth + ]; + var effectiveYSize = [ + (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, + (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth + ]; + var heightScale = effectiveXSize[0] / effectiveYSize[0]; + var widthScale = effectiveXSize[1] / effectiveYSize[1]; + // Reference implementation + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/3039375c86a5bbc9610c7725dcaa95d635f87ba2/tensorflow/core/kernels/resize_bilinear_op.cc#L275 + var dyValues = this.readSync(dy.dataId); + var offset = 0; + for (var b = 0; b < batch; b++) { + var bOffset = b * x.strides[0]; + for (var r = 0; r < yHeight; r++) { + var dxR = r * heightScale; + var topDxRIndex = Math.floor(dxR); + var bottomDxRIndex = Math.min(Math.ceil(dxR), xHeight - 1); + var topDxROffset = bOffset + topDxRIndex * x.strides[1]; + var bottomDxROffset = bOffset + bottomDxRIndex * x.strides[1]; + var dxRLerp = dxR - topDxRIndex; + var inverseDxRLerp = 1.0 - dxRLerp; + for (var c = 0; c < yWidth; c++) { + var dxC = c * widthScale; + var leftDxCIndex = Math.floor(dxC); + var rightDxCIndex = Math.min(Math.ceil(dxC), xWidth - 1); + var dxCLerp = dxC - leftDxCIndex; + var inverseDxCLerp = 1.0 - dxCLerp; + var topLeftRCOffset = topDxROffset + leftDxCIndex * x.strides[2]; + var topRightRCOffset = topDxROffset + rightDxCIndex * x.strides[2]; + var bottomLeftRCOffset = bottomDxROffset + leftDxCIndex * x.strides[2]; + var bottomRightRCOffset = bottomDxROffset + rightDxCIndex * x.strides[2]; + var inverseDxRLerpTimesInverseDxCLerp = inverseDxRLerp * inverseDxCLerp; + var inverseDxRLerpTimesDxCLerp = inverseDxRLerp * dxCLerp; + var dxRLerpTimesInverseDxCLerp = dxRLerp * inverseDxCLerp; + var dxRLerpTimesDxCLerp = dxRLerp * dxCLerp; + for (var d = 0; d < depth; d++) { + var dyVal = dyValues[offset++]; + output[topLeftRCOffset + d] += + dyVal * inverseDxRLerpTimesInverseDxCLerp; + output[topRightRCOffset + d] += dyVal * inverseDxRLerpTimesDxCLerp; + output[bottomLeftRCOffset + d] += + dyVal * dxRLerpTimesInverseDxCLerp; + output[bottomRightRCOffset + d] += dyVal * dxRLerpTimesDxCLerp; + } + } + } + } + return tensor4d(output, [batch, xWidth, xHeight, depth], x.dtype); + }; + MathBackendCPU.prototype.resizeNearestNeighbor = function (x, newHeight, newWidth, alignCorners) { + assertNotComplex(x, 'resizeNearestNeighbor'); + var _a = x.shape, batch = _a[0], oldHeight = _a[1], oldWidth = _a[2], numChannels = _a[3]; + var xValues = this.readSync(x.dataId); + var output = new Float32Array(batch * newHeight * newWidth * numChannels); + var effectiveInputSize = [ + (alignCorners && newHeight > 1) ? oldHeight - 1 : oldHeight, + (alignCorners && newWidth > 1) ? oldWidth - 1 : oldWidth + ]; + var effectiveOutputSize = [ + (alignCorners && newHeight > 1) ? newHeight - 1 : newHeight, + (alignCorners && newWidth > 1) ? newWidth - 1 : newWidth + ]; + var effectiveRowSizeRatio = effectiveInputSize[0] / effectiveOutputSize[0]; + var effectiveColSizeRatio = effectiveInputSize[1] / effectiveOutputSize[1]; + var outputOffset = 0; + for (var b = 0; b < batch; b++) { + var batchOffset = b * x.strides[0]; + for (var r = 0; r < newHeight; r++) { + var sourceFracRow = effectiveRowSizeRatio * r; + var sourceNearestRow = Math.min(oldHeight - 1, alignCorners ? Math.round(sourceFracRow) : + Math.floor(sourceFracRow)); + var rowOffset = batchOffset + sourceNearestRow * x.strides[1]; + for (var c = 0; c < newWidth; c++) { + var sourceFracCol = effectiveColSizeRatio * c; + var sourceNearestCol = Math.min(oldWidth - 1, alignCorners ? Math.round(sourceFracCol) : + Math.floor(sourceFracCol)); + var colOffset = rowOffset + sourceNearestCol * x.strides[2]; + for (var d = 0; d < numChannels; d++) { + // Begin shader. + // Compute the fractional index of the source. + var newVal = xValues[colOffset + d]; + output[outputOffset++] = newVal; + } + } + } + } + return tensor(output, [batch, newHeight, newWidth, numChannels], x.dtype); + }; + MathBackendCPU.prototype.resizeNearestNeighborBackprop = function (dy, x, alignCorners) { + assertNotComplex([dy, x], 'resizeNearestNeighborBackprop'); + var _a = x.shape, batch = _a[0], xHeight = _a[1], xWidth = _a[2], depth = _a[3]; + var _b = dy.shape, yHeight = _b[1], yWidth = _b[2]; + var output = new Float32Array(batch * xHeight * xWidth * depth); + var dyValues = this.readSync(dy.dataId); + // In the backwards pass, we want to find the pixels that were generated + // for each pixel in the input image the forward pass + var effectiveXSize = [ + (alignCorners && yHeight > 1) ? xHeight - 1 : xHeight, + (alignCorners && yWidth > 1) ? xWidth - 1 : xWidth + ]; + var effectiveYSize = [ + (alignCorners && yHeight > 1) ? yHeight - 1 : yHeight, + (alignCorners && yWidth > 1) ? yWidth - 1 : yWidth + ]; + var heightScale = effectiveXSize[0] / effectiveYSize[0]; + var widthScale = effectiveXSize[1] / effectiveYSize[1]; + var invHeightScale = 1 / heightScale; + var invWidthScale = 1 / widthScale; + // This defines the size of the window of values around a particular + // index in dy that we want to search for contributions to dx. + var winHeight = (Math.ceil(invHeightScale) * 2) + 2; + var winWidth = (Math.ceil(invWidthScale) * 2) + 2; + // Loop over the output space. + for (var b = 0; b < batch; b++) { + var batchOffset = b * x.strides[0]; + for (var r = 0; r < xHeight; r++) { + var rowOffset = batchOffset + r * x.strides[1]; + // Compute bounds for where in dy we will look + var startRLerp = Math.floor(r * invHeightScale); + var startDyR = Math.floor(startRLerp - (winHeight / 2)); + for (var c = 0; c < xWidth; c++) { + var colOffset = rowOffset + c * x.strides[2]; + // Compute bounds for where in dy we will look + var startCLerp = Math.floor(c * invWidthScale); + var startDyC = Math.floor(startCLerp - (winWidth / 2)); + for (var d = 0; d < depth; d++) { + var accum = 0; + // loop over dy + for (var dyRIndex = 0; dyRIndex < winHeight; dyRIndex++) { + var dyR = dyRIndex + startDyR; + // Guard against the window exceeding the bounds of dy + if (dyR < 0 || dyR >= yHeight) { + continue; + } + var dyROffset = batchOffset + dyR * dy.strides[1]; + var sourceFracRow = dyR * heightScale; + var sourceNearestRow = Math.min(xHeight - 1, alignCorners ? Math.round(sourceFracRow) : + Math.floor(sourceFracRow)); + if (r !== sourceNearestRow) { + continue; + } + for (var dyCIndex = 0; dyCIndex < winWidth; dyCIndex++) { + var dyC = dyCIndex + startDyC; + // Guard against the window exceeding the bounds of dy + if (dyC < 0 || dyC >= yWidth) { + continue; + } + var dyCOffset = dyROffset + dyC * dy.strides[2]; + var sourceFracCol = dyC * widthScale; + var sourceNearestCol = Math.min(xWidth - 1, alignCorners ? Math.round(sourceFracCol) : + Math.floor(sourceFracCol)); + if (c === sourceNearestCol) { + accum += dyValues[dyCOffset + d]; + } + } + } + output[colOffset + d] = accum; + } + } + } + } + return tensor4d(output, x.shape, x.dtype); + }; + MathBackendCPU.prototype.batchNormalization = function (x, mean, variance, varianceEpsilon, scale, offset) { + assertNotComplex([x, mean, variance, scale, offset], 'batchNorm'); + var xVals = this.readSync(x.dataId); + var mVals = this.readSync(mean.dataId); + var varVals = this.readSync(variance.dataId); + var sVals = scale ? this.readSync(scale.dataId) : + new Float32Array([1]); + var offVals = offset ? this.readSync(offset.dataId) : + new Float32Array([0]); + var outVals = new Float32Array(xVals.length); + var offValsLength = offVals.length; + var sValsLength = sVals.length; + var varValsLength = varVals.length; + var mValsLength = mVals.length; + var offi = 0; + var mi = 0; + var si = 0; + var vi = 0; + for (var i = 0; i < xVals.length; ++i) { + outVals[i] = offVals[offi++] + + (xVals[i] - mVals[mi++]) * sVals[si++] / + Math.sqrt(varVals[vi++] + varianceEpsilon); + if (offi >= offValsLength) { + offi = 0; + } + if (mi >= mValsLength) { + mi = 0; + } + if (si >= sValsLength) { + si = 0; + } + if (vi >= varValsLength) { + vi = 0; + } + } + return tensor4d(outVals, x.shape); + }; + MathBackendCPU.prototype.localResponseNormalization4D = function (x, depthRadius, bias, alpha, beta) { + assertNotComplex(x, 'localResponseNormalization4D'); + var channels = x.shape[3]; + var maxD = channels - 1; + var xValues = this.readSync(x.dataId); + var size = x.size; + var result = new Float32Array(size); + function sumAcrossChannels(offset) { + var currentChannel = offset % channels; + var beginSumOffset = offset - currentChannel + Math.max(0, currentChannel - depthRadius); + var endSumOffset = offset - currentChannel + + Math.min(currentChannel + depthRadius, maxD); + var sum = 0.0; + for (; beginSumOffset <= endSumOffset; beginSumOffset++) { + var z = xValues[beginSumOffset]; + sum += z * z; + } + return sum; + } + for (var offset = 0; offset < size; offset++) { + var sum = sumAcrossChannels(offset); + var val = xValues[offset] * Math.pow(bias + alpha * sum, -beta); + result[offset] = val; + } + return tensor4d(result, x.shape); + }; + MathBackendCPU.prototype.LRNGrad = function (dy, inputImage, outputImage, depthRadius, bias, alpha, beta) { + assertNotComplex(dy, 'LRNGrad'); + var channels = dy.shape[3]; + var dyValues = this.readSync(dy.dataId); + var inputImageValues = this.readSync(inputImage.dataId); + var outputImageValues = this.readSync(outputImage.dataId); + var result = new Float32Array(dy.size); + var size = dy.size; + for (var offset = 0; offset < size; offset++) { + var currentChannel = offset % channels; + var depthBegin = (offset - currentChannel) + Math.max(0, currentChannel - depthRadius); + var depthEnd = (offset - currentChannel) + + Math.min(channels, currentChannel + depthRadius + 1); + var norm = 0; + for (var k = depthBegin; k < depthEnd; k++) { + norm += Math.pow(inputImageValues[k], 2); + } + norm = alpha * norm + bias; + for (var k = depthBegin; k < depthEnd; k++) { + var dyi = -2 * alpha * beta * inputImageValues[k] * + outputImageValues[offset] / norm; + if (offset === k) { + dyi += Math.pow(norm, -beta); + } + dyi *= dyValues[offset]; + result[k] += dyi; + } + } + return tensor4d(result, dy.shape); + }; + MathBackendCPU.prototype.multinomial = function (logits, normalized, numSamples, seed) { + assertNotComplex(logits, 'multinomial'); + var probabilities = normalized ? logits : softmax(logits); + var batchSize = probabilities.shape[0]; + var numEvents = probabilities.shape[1]; + var res = zeros([batchSize, numSamples], 'int32'); + var resVals = this.readSync(res.dataId); + var probVals = this.readSync(probabilities.dataId); + for (var b = 0; b < batchSize; ++b) { + var offset = b * numEvents; + // The cdf won't include the last event. It will be implicit if no other + // event happened. + var cdf = new Float32Array(numEvents - 1); + cdf[0] = probVals[offset]; + for (var event_1 = 1; event_1 < cdf.length; ++event_1) { + cdf[event_1] = cdf[event_1 - 1] + probVals[offset + event_1]; + } + var random = seedrandom_1(seed.toString()); + var outOffset = b * numSamples; + for (var sampleId = 0; sampleId < numSamples; ++sampleId) { + var r = random(); + // Assume last event happened by default. + resVals[outOffset + sampleId] = cdf.length; + for (var event_2 = 0; event_2 < cdf.length; event_2++) { + if (r < cdf[event_2]) { + resVals[outOffset + sampleId] = event_2; + break; + } + } + } + } + return res; + }; + MathBackendCPU.prototype.oneHot = function (indices, depth, onValue, offValue) { + assertNotComplex(indices, 'oneHot'); + var res = new Float32Array(indices.size * depth); + res.fill(offValue); + var indicesVal = this.readSync(indices.dataId); + for (var event_3 = 0; event_3 < indices.size; ++event_3) { + if (indicesVal[event_3] >= 0 && indicesVal[event_3] < depth) { + res[event_3 * depth + indicesVal[event_3]] = onValue; + } + } + return tensor2d(res, [indices.size, depth], 'int32'); + }; + MathBackendCPU.prototype.nonMaxSuppression = function (boxes, scores, maxOutputSize, iouThreshold, scoreThreshold) { + assertNotComplex(boxes, 'nonMaxSuppression'); + var boxesVals = this.readSync(boxes.dataId); + var scoresVals = this.readSync(scores.dataId); + return nonMaxSuppressionImpl(boxesVals, scoresVals, maxOutputSize, iouThreshold, scoreThreshold); + }; + MathBackendCPU.prototype.fft = function (x) { + return this.fftBatch(x, false); + }; + MathBackendCPU.prototype.ifft = function (x) { + return this.fftBatch(x, true); + }; + /** + * Calculate FFT of inner most elements of batch tensor. + */ + MathBackendCPU.prototype.fftBatch = function (x, inverse) { + var batch = x.shape[0]; + var innerDim = x.shape[1]; + // Collects real and imaginary values separately. + var realResult = buffer(x.shape, 'float32'); + var imagResult = buffer(x.shape, 'float32'); + var real$1 = real(x).as2D(batch, innerDim); + var imag$1 = imag(x).as2D(batch, innerDim); + for (var b = 0; b < batch; b++) { + // TODO: Support slice ops for complex type. + var r = real$1.slice([b, 0], [1, innerDim]); + var i = imag$1.slice([b, 0], [1, innerDim]); + var input = complex(r, i); + // Run FFT by batch element. + var res = this.readSync(this.fftImpl(input, inverse).dataId); + for (var d = 0; d < innerDim; d++) { + var c = getComplexWithIndex(res, d); + realResult.values[b * innerDim + d] = c.real; + imagResult.values[b * innerDim + d] = c.imag; + } + } + var t = complex(realResult.toTensor(), imagResult.toTensor()); + return t.as2D(batch, innerDim); + }; + MathBackendCPU.prototype.fftImpl = function (x, inverse) { + var x1D = x.as1D(); + var n = x1D.size; + if (this.isExponentOf2(n)) { + var result = this.fftRadix2(x1D, n, inverse).as2D(x.shape[0], x.shape[1]); + if (inverse) { + result = complex(real(result).div(scalar(n)), imag(result).div(scalar(n))); + } + return result; + } + else { + var data = this.readSync(x.dataId); + var rawOutput = this.fourierTransformByMatmul(data, n, inverse); + var output = splitRealAndImagArrays(rawOutput); + return complex(output.real, output.imag).as2D(x.shape[0], x.shape[1]); + } + }; + MathBackendCPU.prototype.isExponentOf2 = function (size) { + return (size & size - 1) === 0; + }; + // FFT using Cooley-Tukey algorithm on radix 2 dimensional input. + MathBackendCPU.prototype.fftRadix2 = function (input, size, inverse) { + if (size === 1) { + return input; + } + var data = this.readSync(input.dataId); + var half = size / 2; + var evenComplex = complexWithEvenIndex(data); + var evenTensor = complex(evenComplex.real, evenComplex.imag).as1D(); + var oddComplex = complexWithOddIndex(data); + var oddTensor = complex(oddComplex.real, oddComplex.imag).as1D(); + // Recursive call for half part of original input. + evenTensor = this.fftRadix2(evenTensor, half, inverse); + oddTensor = this.fftRadix2(oddTensor, half, inverse); + var e = exponents(size, inverse); + var exponent = complex(e.real, e.imag).mul(oddTensor); + var addPart = evenTensor.add(exponent); + var subPart = evenTensor.sub(exponent); + var realTensor = real(addPart).concat(real(subPart)); + var imagTensor = imag(addPart).concat(imag(subPart)); + return complex(realTensor, imagTensor).as1D(); + }; + // Calculate fourier transform by multplying sinusoid matrix. + MathBackendCPU.prototype.fourierTransformByMatmul = function (data, size, inverse) { + var ret = new Float32Array(size * 2); + // TODO: Use matmul instead once it supports complex64 type. + for (var r = 0; r < size; r++) { + var real_2 = 0.0; + var imag_2 = 0.0; + for (var c = 0; c < size; c++) { + var e = exponent(r * c, size, inverse); + var term = getComplexWithIndex(data, c); + real_2 += term.real * e.real - term.imag * e.imag; + imag_2 += term.real * e.imag + term.imag * e.real; + } + if (inverse) { + real_2 /= size; + imag_2 /= size; + } + assignToTypedArray(ret, real_2, imag_2, r); + } + return ret; + }; + MathBackendCPU.prototype.depthToSpace = function (x, blockSize, dataFormat) { + assert(dataFormat === 'NHWC', function () { return "Only NHWC dataFormat supported on CPU for depthToSpace. Got " + dataFormat; }); + assert(blockSize > 1, function () { + return "blockSize should be > 1 for depthToSpace, but was: " + blockSize; + }); + var batchSize = x.shape[0]; + var inputHeight = x.shape[1]; + var inputWidth = x.shape[2]; + var inputDepth = x.shape[3]; + var outputHeight = inputHeight * blockSize; + var outputWidth = inputWidth * blockSize; + var outputDepth = inputDepth / (blockSize * blockSize); + var xValues = this.readSync(x.dataId); + var result = new Float32Array(batchSize * outputHeight * outputWidth * outputDepth); + var outputIdx = 0; + for (var b = 0; b < batchSize; ++b) { + for (var h = 0; h < outputHeight; ++h) { + var inH = Math.floor(h / blockSize); + var offsetH = (h % blockSize); + for (var w = 0; w < outputWidth; ++w) { + var inW = Math.floor(w / blockSize); + var offsetW = (w % blockSize); + var offsetD = (offsetH * blockSize + offsetW) * outputDepth; + for (var d = 0; d < outputDepth; ++d) { + var inD = d + offsetD; + var inputIdx = inD + inputDepth * (inW + inputWidth * (inH + inputHeight * b)); + result[outputIdx++] = xValues[inputIdx]; + } + } + } + } + return tensor4d(result, [batchSize, outputHeight, outputWidth, outputDepth]); + }; + MathBackendCPU.prototype.broadcastedBinaryOp = function (a, b, dtype, op) { + var newShape = assertAndGetBroadcastShape(a.shape, b.shape); + var result = buffer(newShape, dtype); + var aVals = this.readSync(a.dataId); + var bVals = this.readSync(b.dataId); + var aBroadcastDims = getBroadcastDims(a.shape, newShape); + var bBroadcastDims = getBroadcastDims(b.shape, newShape); + var resVals = result.values; + if (aBroadcastDims.length + bBroadcastDims.length === 0) { + for (var i = 0; i < resVals.length; ++i) { + resVals[i] = op(aVals[i % aVals.length], bVals[i % bVals.length]); + } + } + else { + var aBuf = this.bufferSync(a); + var bBuf = this.bufferSync(b); + var _loop_2 = function (i) { + var loc = result.indexToLoc(i); + var aLoc = loc.slice(-a.rank); + aBroadcastDims.forEach(function (d) { return aLoc[d] = 0; }); + var aIndex = aBuf.locToIndex(aLoc); + var bLoc = loc.slice(-b.rank); + bBroadcastDims.forEach(function (d) { return bLoc[d] = 0; }); + var bIndex = bBuf.locToIndex(bLoc); + resVals[i] = op(aVals[aIndex], bVals[bIndex]); + }; + for (var i = 0; i < resVals.length; ++i) { + _loop_2(i); + } + } + return result.toTensor(); + }; + MathBackendCPU.prototype.broadcastedBinaryComplexOp = function (a, b, op) { + var newShape = assertAndGetBroadcastShape(a.shape, b.shape); + var realResult = buffer(newShape, 'float32'); + var imagResult = buffer(newShape, 'float32'); + var aVals = this.readSync(a.dataId); + var bVals = this.readSync(b.dataId); + var aBroadcastDims = getBroadcastDims(a.shape, newShape); + var bBroadcastDims = getBroadcastDims(b.shape, newShape); + var realVals = realResult.values; + var imagVals = imagResult.values; + if (aBroadcastDims.length + bBroadcastDims.length === 0) { + for (var i = 0; i < realVals.length; i++) { + var aIdx = i % aVals.length; + var bIdx = i % bVals.length; + var result = op(aVals[aIdx * 2], aVals[aIdx * 2 + 1], bVals[bIdx * 2], bVals[bIdx * 2 + 1]); + realVals[i] = result.real; + imagVals[i] = result.imag; + } + } + else { + var aRealBuf = this.bufferSync(this.data.get(a.dataId).complexTensors.real); + var bRealBuf = this.bufferSync(this.data.get(b.dataId).complexTensors.real); + var _loop_3 = function (i) { + var loc = realResult.indexToLoc(i); + var aLoc = loc.slice(-a.rank); + aBroadcastDims.forEach(function (d) { return aLoc[d] = 0; }); + var aIndex = aRealBuf.locToIndex(aLoc); + var bLoc = loc.slice(-b.rank); + bBroadcastDims.forEach(function (d) { return bLoc[d] = 0; }); + var bIndex = bRealBuf.locToIndex(bLoc); + var opResult = op(aVals[aIndex * 2], aVals[aIndex * 2 + 1], bVals[bIndex * 2], bVals[bIndex * 2 + 1]); + realVals[i] = opResult.real; + imagVals[i] = opResult.imag; + }; + for (var i = 0; i < realVals.length; i++) { + _loop_3(i); + } + } + return this.complex(realResult.toTensor(), imagResult.toTensor()); + }; + MathBackendCPU.prototype.split = function (x, sizeSplits, axis) { + return split$1(x, sizeSplits, axis); + }; + MathBackendCPU.prototype.dispose = function () { }; + MathBackendCPU.prototype.floatPrecision = function () { + return 32; + }; + /** Returns the smallest representable number. */ + MathBackendCPU.prototype.epsilon = function () { + return EPSILON_FLOAT32; + }; + MathBackendCPU.prototype.cropAndResize = function (images, boxes, boxIndex, cropSize, method, extrapolationValue) { + var _a = images.shape, batch = _a[0], imageHeight = _a[1], imageWidth = _a[2], numChannels = _a[3]; + var numBoxes = boxes.shape[0]; + var cropHeight = cropSize[0], cropWidth = cropSize[1]; + var output = buffer([numBoxes, cropHeight, cropWidth, numChannels], images.dtype); + var boxVals = this.readSync(boxes.dataId); + var boxIndVals = this.readSync(boxIndex.dataId); + var imageVals = this.readSync(images.dataId); + var inStride = images.strides; // to calculate flat indexes into image + var outStride = output.strides; // to calculate flat indexes into output + // Reference implementation + // tslint:disable-next-line:max-line-length + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/crop_and_resize_op.cc + for (var b = 0; b < numBoxes; b++) { + var startInd = b * 4; + var y1 = boxVals[startInd]; + var x1 = boxVals[startInd + 1]; + var y2 = boxVals[startInd + 2]; + var x2 = boxVals[startInd + 3]; + var bInd = boxIndVals[b]; + if (bInd >= batch) { + continue; + } + var heightScale = (cropHeight > 1) ? + (y2 - y1) * (imageHeight - 1) / (cropHeight - 1) : + 0; + var widthScale = (cropWidth > 1) ? (x2 - x1) * (imageWidth - 1) / (cropWidth - 1) : 0; + for (var y = 0; y < cropHeight; y++) { + var yInd = (cropHeight > 1) ? + y1 * (imageHeight - 1) + y * (heightScale) : + 0.5 * (y1 + y2) * (imageHeight - 1); + if (yInd < 0 || yInd > imageHeight - 1) { + for (var x = 0; x < cropWidth; x++) { + for (var c = 0; c < numChannels; c++) { + var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = extrapolationValue; + } + } + continue; + } + if (method === 'bilinear') { + var topInd = Math.floor(yInd); + var bottomInd = Math.ceil(yInd); + var yLerp = yInd - topInd; + for (var x = 0; x < cropWidth; x++) { + var xInd = (cropWidth > 1) ? + x1 * (imageWidth - 1) + x * widthScale : + 0.5 * (x1 + x2) * (imageWidth - 1); + if (xInd < 0 || xInd > imageWidth - 1) { + for (var c = 0; c < numChannels; c++) { + var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = extrapolationValue; + } + continue; + } + var leftInd = Math.floor(xInd); + var rightInd = Math.ceil(xInd); + var xLerp = xInd - leftInd; + for (var c = 0; c < numChannels; c++) { + var ind = c + leftInd * inStride[2] + topInd * inStride[1] + + bInd * inStride[0]; + var topLeft = imageVals[ind]; + ind = c + rightInd * inStride[2] + topInd * inStride[1] + + bInd * inStride[0]; + var topRight = imageVals[ind]; + ind = c + leftInd * inStride[2] + bottomInd * inStride[1] + + bInd * inStride[0]; + var bottomLeft = imageVals[ind]; + ind = c + rightInd * inStride[2] + bottomInd * inStride[1] + + bInd * inStride[0]; + var bottomRight = imageVals[ind]; + var top_2 = topLeft + (topRight - topLeft) * xLerp; + var bottom = bottomLeft + (bottomRight - bottomLeft) * xLerp; + ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = top_2 + ((bottom - top_2) * yLerp); + } + } + } + else { // method == "nearest" + for (var x = 0; x < cropWidth; ++x) { + var xInd = (cropWidth > 1) ? + x1 * (imageWidth - 1) + x * widthScale : + 0.5 * (x1 + x2) * (imageWidth - 1); + if (xInd < 0 || xInd > imageWidth - 1) { + for (var c = 0; c < numChannels; c++) { + var ind = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[ind] = extrapolationValue; + } + continue; + } + var closestX = Math.round(xInd); + var closestY = Math.round(yInd); + for (var c = 0; c < numChannels; c++) { + var inInd = c + closestX * inStride[2] + + closestY * inStride[1] + bInd * inStride[0]; + var outInd = c + x * outStride[2] + y * outStride[1] + b * outStride[0]; + output.values[outInd] = imageVals[inInd]; + } + } + } + } + } + return output.toTensor(); + }; + MathBackendCPU.prototype.sparseToDense = function (sparseIndices, sparseValues, outputShape, defaultValue) { + var _a = calculateShapes(sparseValues, sparseIndices, outputShape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize; + var sumDupeIndices = false; + return this.scatter(sparseIndices, sparseValues, outputShape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices); + }; + MathBackendCPU.prototype.gatherND = function (x, indices) { + var indicesShape = indices.shape; + var sliceRank = indicesShape[indicesShape.length - 1]; + var _a = prepareAndValidate(x, indices), resultShape = _a[0], numSlices = _a[1], sliceSize = _a[2], strides = _a[3]; + if (numSlices === 0) { + return tensor([], resultShape, x.dtype); + } + var buffer = new TensorBuffer([numSlices, sliceSize], x.dtype); + var indicesData = this.readSync(indices.dataId); + var xData = this.readSync(x.dataId); + for (var i = 0; i < numSlices; i++) { + var index = []; + var flattenIndex = 0; + for (var j = 0; j < sliceRank; j++) { + var dim = indicesData[i * sliceRank + j]; + flattenIndex += dim * strides[j]; + index.push(dim); + } + if (flattenIndex < 0 || flattenIndex >= x.size / sliceSize) { + throw new Error("Invalid indices: " + index + " does not index into " + x.shape); + } + for (var k = 0; k < sliceSize; k++) { + buffer.values[i * sliceSize + k] = xData[flattenIndex * sliceSize + k]; + } + } + return buffer.toTensor().reshape(resultShape); + }; + MathBackendCPU.prototype.scatterND = function (indices, updates, shape) { + var _a = calculateShapes(updates, indices, shape), sliceRank = _a.sliceRank, numUpdates = _a.numUpdates, sliceSize = _a.sliceSize, strides = _a.strides, outputSize = _a.outputSize; + var defaultValue = scalar(0); + var sumDupeIndices = true; + return this.scatter(indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices); + }; + MathBackendCPU.prototype.fill = function (shape, value, dtype) { + dtype = dtype || inferDtype(value); + var values = getArrayFromDType(dtype, sizeFromShape(shape)); + values.fill(value); + return ENGINE.makeTensor(values, shape, dtype, this); + }; + MathBackendCPU.prototype.onesLike = function (x) { + if (x.dtype === 'string') { + throw new Error('onesLike is not supported for string tensors'); + } + else { + return this.fill(x.shape, 1, x.dtype); + } + }; + MathBackendCPU.prototype.zerosLike = function (x) { + var values = getArrayFromDType(x.dtype, sizeFromShape(x.shape)); + return this.makeOutput(values, x.shape, x.dtype); + }; + MathBackendCPU.prototype.linspace = function (start, stop, num) { + return linspaceImpl(start, stop, num); + }; + MathBackendCPU.prototype.scatter = function (indices, updates, shape, outputSize, sliceSize, numUpdates, sliceRank, strides, defaultValue, sumDupeIndices) { + var flattenShape = [outputSize / sliceSize, sliceSize]; + var indicesData = this.readSync(indices.dataId); + var updatesData = this.readSync(updates.dataId); + if (outputSize === 0) { + return tensor([], shape, updates.dtype); + } + var buffer = new TensorBuffer(flattenShape, updates.dtype); + buffer.values.fill(this.readSync(defaultValue.dataId)[0]); + for (var i = 0; i < numUpdates; i++) { + var index = []; + var flattenIndex = 0; + for (var j = 0; j < sliceRank; j++) { + var dim = indicesData[i * sliceRank + j]; + index.push(dim); + flattenIndex += dim * strides[j]; + } + if (flattenIndex < 0 || flattenIndex >= outputSize / sliceSize) { + throw new Error("Invalid indices: " + index + " does not index into " + shape); + } + for (var k = 0; k < sliceSize; k++) { + if (sumDupeIndices) { + buffer.values[flattenIndex * sliceSize + k] += + updatesData[i * sliceSize + k]; + } + else { + buffer.values[flattenIndex * sliceSize + k] = updates.rank === 0 ? + updatesData[0] : + updatesData[i * sliceSize + k]; + } + } + } + return buffer.toTensor().reshape(shape); + }; + return MathBackendCPU; + }(KernelBackend)); + ENGINE.registerBackend('cpu', function () { return new MathBackendCPU(); }, 1 /* priority */); + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerKernel({ + kernelName: 'Square', + backendName: 'cpu', + kernelFunc: function (_a) { + var inputs = _a.inputs, backend = _a.backend; + var x = inputs.x; + var cpuBackend = backend; + assertNotComplex(x, 'square'); + var values = cpuBackend.data.get(x.dataId).values; + var newValues = new Float32Array(values.length); + for (var i = 0; i < values.length; ++i) { + var value = values[i]; + newValues[i] = value * value; + } + var dataId = cpuBackend.write(newValues, x.shape, x.dtype); + return { dataId: dataId, shape: x.shape, dtype: x.dtype }; + } + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + registerKernel({ + kernelName: 'Square', + backendName: 'webgl', + kernelFunc: function (_a) { + var inputs = _a.inputs, backend = _a.backend; + var x = inputs.x; + var webglBackend = backend; + var program = new UnaryOpProgram(x.shape, SQUARE); + return webglBackend.runWebGLProgram(program, [x], x.dtype); + } + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PlatformBrowser = /** @class */ (function () { + function PlatformBrowser() { + } + PlatformBrowser.prototype.fetch = function (path, init) { + return fetch(path, init); + }; + PlatformBrowser.prototype.now = function () { + return performance.now(); + }; + PlatformBrowser.prototype.encode = function (text, encoding) { + if (encoding !== 'utf-8' && encoding !== 'utf8') { + throw new Error("Browser's encoder only supports utf-8, but got " + encoding); + } + if (this.textEncoder == null) { + this.textEncoder = new TextEncoder(); + } + return this.textEncoder.encode(text); + }; + PlatformBrowser.prototype.decode = function (bytes, encoding) { + return new TextDecoder(encoding).decode(bytes); + }; + return PlatformBrowser; + }()); + if (env().get('IS_BROWSER')) { + env().setPlatform('browser', new PlatformBrowser()); + } + + /** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + // We are wrapping this within an object so it can be stubbed by Jasmine. + var getNodeFetch = { + // tslint:disable-next-line:no-require-imports + importFetch: function () { return require('node-fetch'); } + }; + var systemFetch; + var PlatformNode = /** @class */ (function () { + function PlatformNode() { + // tslint:disable-next-line:no-require-imports + this.util = require('util'); + // According to the spec, the built-in encoder can do only UTF-8 encoding. + // https://developer.mozilla.org/en-US/docs/Web/API/TextEncoder/TextEncoder + this.textEncoder = new this.util.TextEncoder(); + } + PlatformNode.prototype.fetch = function (path, requestInits) { + if (env().global.fetch != null) { + return env().global.fetch(path, requestInits); + } + if (systemFetch == null) { + systemFetch = getNodeFetch.importFetch(); + } + return systemFetch(path, requestInits); + }; + PlatformNode.prototype.now = function () { + var time = process.hrtime(); + return time[0] * 1000 + time[1] / 1000000; + }; + PlatformNode.prototype.encode = function (text, encoding) { + if (encoding !== 'utf-8' && encoding !== 'utf8') { + throw new Error("Node built-in encoder only supports utf-8, but got " + encoding); + } + return this.textEncoder.encode(text); + }; + PlatformNode.prototype.decode = function (bytes, encoding) { + if (bytes.length === 0) { + return ''; + } + return new this.util.TextDecoder(encoding).decode(bytes); + }; + return PlatformNode; + }()); + if (env().get('IS_NODE')) { + env().setPlatform('node', new PlatformNode()); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /* Type definitions for exporting and importing of models. */ + /** + * A map from Tensor dtype to number of bytes per element of the Tensor. + */ + var DTYPE_VALUE_SIZE_MAP = { + 'float32': 4, + 'int32': 4, + 'uint16': 2, + 'uint8': 1, + 'bool': 1, + }; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** Number of bytes reserved for the length of the string. (32bit integer). */ + var NUM_BYTES_STRING_LENGTH = 4; + /** + * Encode a map from names to weight values as an ArrayBuffer, along with an + * `Array` of `WeightsManifestEntry` as specification of the encoded weights. + * + * This function does not perform sharding. + * + * This function is the reverse of `decodeWeights`. + * + * @param tensors A map ("dict") from names to tensors. + * @param group Group to which the weights belong (optional). + * @returns A `Promise` of + * - A flat `ArrayBuffer` with all the binary values of the `Tensor`s + * concatenated. + * - An `Array` of `WeightManifestEntry`s, carrying information including + * tensor names, `dtype`s and shapes. + * @throws Error: on unsupported tensor `dtype`. + */ + function encodeWeights(tensors, group) { + return __awaiter(this, void 0, void 0, function () { + var specs, dataPromises, names, _loop_1, i, tensorValues; + var _this = this; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + specs = []; + dataPromises = []; + names = Array.isArray(tensors) ? + tensors.map(function (tensor) { return tensor.name; }) : + Object.keys(tensors); + _loop_1 = function (i) { + var name_1 = names[i]; + var t = Array.isArray(tensors) ? tensors[i].tensor : tensors[name_1]; + if (t.dtype !== 'float32' && t.dtype !== 'int32' && t.dtype !== 'bool' && + t.dtype !== 'string') { + throw new Error("Unsupported dtype in weight '" + name_1 + "': " + t.dtype); + } + var spec = { name: name_1, shape: t.shape, dtype: t.dtype }; + if (t.dtype === 'string') { + var utf8bytes = new Promise(function (resolve) { return __awaiter(_this, void 0, void 0, function () { + var vals, totalNumBytes, bytes, offset, i_1, val, bytesOfLength; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, t.bytes()]; + case 1: + vals = _a.sent(); + totalNumBytes = vals.reduce(function (p, c) { return p + c.length; }, 0) + + NUM_BYTES_STRING_LENGTH * vals.length; + bytes = new Uint8Array(totalNumBytes); + offset = 0; + for (i_1 = 0; i_1 < vals.length; i_1++) { + val = vals[i_1]; + bytesOfLength = new Uint8Array(new Uint32Array([val.length]).buffer); + bytes.set(bytesOfLength, offset); + offset += NUM_BYTES_STRING_LENGTH; + bytes.set(val, offset); + offset += val.length; + } + resolve(bytes); + return [2 /*return*/]; + } + }); + }); }); + dataPromises.push(utf8bytes); + } + else { + dataPromises.push(t.data()); + } + if (group != null) { + spec.group = group; + } + specs.push(spec); + }; + for (i = 0; i < names.length; ++i) { + _loop_1(i); + } + return [4 /*yield*/, Promise.all(dataPromises)]; + case 1: + tensorValues = _a.sent(); + return [2 /*return*/, { data: concatenateTypedArrays(tensorValues), specs: specs }]; + } + }); + }); + } + /** + * Decode flat ArrayBuffer as weights. + * + * This function does not handle sharding. + * + * This function is the reverse of `encodeWeights`. + * + * @param buffer A flat ArrayBuffer carrying the binary values of the tensors + * concatenated in the order specified in `specs`. + * @param specs Specifications of the names, dtypes and shapes of the tensors + * whose value are encoded by `buffer`. + * @return A map from tensor name to tensor value, with the names corresponding + * to names in `specs`. + * @throws Error, if any of the tensors has unsupported dtype. + */ + function decodeWeights(buffer, specs) { + // TODO(adarob, cais): Support quantization. + var out = {}; + var offset = 0; + var _loop_2 = function (spec) { + var name_2 = spec.name; + var dtype = spec.dtype; + var shape = spec.shape; + var size = sizeFromShape(shape); + var values = void 0; + if ('quantization' in spec) { + var quantization_1 = spec.quantization; + if (quantization_1.dtype !== 'uint8' && quantization_1.dtype !== 'uint16') { + throw new Error("Weight " + spec.name + " has unknown " + + ("quantization dtype " + quantization_1.dtype + ". ") + + "Supported quantization dtypes are: 'uint8' and 'uint16'."); + } + var quantizationSizeFactor = DTYPE_VALUE_SIZE_MAP[quantization_1.dtype]; + var byteBuffer = buffer.slice(offset, offset + size * quantizationSizeFactor); + var quantizedArray = (quantization_1.dtype === 'uint8') ? + new Uint8Array(byteBuffer) : + new Uint16Array(byteBuffer); + if (dtype === 'float32') { + values = Float32Array.from(quantizedArray, function (v) { return v * quantization_1.scale + quantization_1.min; }); + } + else if (dtype === 'int32') { + values = Int32Array.from(quantizedArray, function (v) { return Math.round(v * quantization_1.scale + quantization_1.min); }); + } + else { + throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); + } + offset += size * quantizationSizeFactor; + } + else if (dtype === 'string') { + var size_1 = sizeFromShape(spec.shape); + values = []; + for (var i = 0; i < size_1; i++) { + var byteLength = new Uint32Array(buffer.slice(offset, offset + NUM_BYTES_STRING_LENGTH))[0]; + offset += NUM_BYTES_STRING_LENGTH; + var bytes = new Uint8Array(buffer.slice(offset, offset + byteLength)); + values.push(bytes); + offset += byteLength; + } + } + else { + var dtypeFactor = DTYPE_VALUE_SIZE_MAP[dtype]; + var byteBuffer = buffer.slice(offset, offset + size * dtypeFactor); + if (dtype === 'float32') { + values = new Float32Array(byteBuffer); + } + else if (dtype === 'int32') { + values = new Int32Array(byteBuffer); + } + else if (dtype === 'bool') { + values = new Uint8Array(byteBuffer); + } + else { + throw new Error("Unsupported dtype in weight '" + name_2 + "': " + dtype); + } + offset += size * dtypeFactor; + } + out[name_2] = tensor(values, shape, dtype); + }; + for (var _i = 0, specs_1 = specs; _i < specs_1.length; _i++) { + var spec = specs_1[_i]; + _loop_2(spec); + } + return out; + } + /** + * Concatenate TypedArrays into an ArrayBuffer. + */ + function concatenateTypedArrays(xs) { + // TODO(adarob, cais): Support quantization. + if (xs === null) { + throw new Error("Invalid input value: " + JSON.stringify(xs)); + } + var totalByteLength = 0; + // `normalizedXs` is here for this reason: a `TypedArray`'s `buffer' + // can have a different byte length from that of the `TypedArray` itself, + // for example, when the `TypedArray` is created from an offset in an + // `ArrayBuffer`. `normliazedXs` holds `TypedArray`s whose `buffer`s match + // the `TypedArray` in byte length. If an element of `xs` does not show + // this property, a new `TypedArray` that satisfy this property will be + // constructed and pushed into `normalizedXs`. + var normalizedXs = []; + xs.forEach(function (x) { + totalByteLength += x.byteLength; + // tslint:disable:no-any + normalizedXs.push(x.byteLength === x.buffer.byteLength ? x : + new x.constructor(x)); + if (!(x instanceof Float32Array || x instanceof Int32Array || + x instanceof Uint8Array)) { + throw new Error("Unsupported TypedArray subtype: " + x.constructor.name); + } + // tslint:enable:no-any + }); + var y = new Uint8Array(totalByteLength); + var offset = 0; + normalizedXs.forEach(function (x) { + y.set(new Uint8Array(x.buffer), offset); + offset += x.byteLength; + }); + return y.buffer; + } + // Use Buffer on Node.js instead of Blob/atob/btoa + var useNodeBuffer = typeof Buffer !== 'undefined' && + (typeof Blob === 'undefined' || typeof atob === 'undefined' || + typeof btoa === 'undefined'); + /** + * Calculate the byte length of a JavaScript string. + * + * Note that a JavaScript string can contain wide characters, therefore the + * length of the string is not necessarily equal to the byte length. + * + * @param str Input string. + * @returns Byte length. + */ + function stringByteLength(str) { + if (useNodeBuffer) { + return Buffer.byteLength(str); + } + return new Blob([str]).size; + } + /** + * Encode an ArrayBuffer as a base64 encoded string. + * + * @param buffer `ArrayBuffer` to be converted. + * @returns A string that base64-encodes `buffer`. + */ + function arrayBufferToBase64String(buffer) { + if (useNodeBuffer) { + return Buffer.from(buffer).toString('base64'); + } + return btoa(String.fromCharCode.apply(null, new Uint8Array(buffer))); + } + /** + * Decode a base64 string as an ArrayBuffer. + * + * @param str Base64 string. + * @returns Decoded `ArrayBuffer`. + */ + function base64StringToArrayBuffer(str) { + if (useNodeBuffer) { + var buf = Buffer.from(str, 'base64'); + return buf.buffer.slice(buf.byteOffset, buf.byteOffset + buf.byteLength); + } + var s = atob(str); + var buffer = new Uint8Array(s.length); + for (var i = 0; i < s.length; ++i) { + buffer.set([s.charCodeAt(i)], i); + } + return buffer.buffer; + } + /** + * Concatenate a number of ArrayBuffers into one. + * + * @param buffers A number of array buffers to concatenate. + * @returns Result of concatenating `buffers` in order. + */ + function concatenateArrayBuffers(buffers) { + var totalByteLength = 0; + buffers.forEach(function (buffer) { + totalByteLength += buffer.byteLength; + }); + var temp = new Uint8Array(totalByteLength); + var offset = 0; + buffers.forEach(function (buffer) { + temp.set(new Uint8Array(buffer), offset); + offset += buffer.byteLength; + }); + return temp.buffer; + } + /** + * Get the basename of a path. + * + * Behaves in a way analogous to Linux's basename command. + * + * @param path + */ + function basename(path) { + var SEPARATOR = '/'; + path = path.trim(); + while (path.endsWith(SEPARATOR)) { + path = path.slice(0, path.length - 1); + } + var items = path.split(SEPARATOR); + return items[items.length - 1]; + } + /** + * Populate ModelArtifactsInfo fields for a model with JSON topology. + * @param modelArtifacts + * @returns A ModelArtifactsInfo object. + */ + function getModelArtifactsInfoForJSON(modelArtifacts) { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('Expected JSON model topology, received ArrayBuffer.'); + } + return { + dateSaved: new Date(), + modelTopologyType: 'JSON', + modelTopologyBytes: modelArtifacts.modelTopology == null ? + 0 : + stringByteLength(JSON.stringify(modelArtifacts.modelTopology)), + weightSpecsBytes: modelArtifacts.weightSpecs == null ? + 0 : + stringByteLength(JSON.stringify(modelArtifacts.weightSpecs)), + weightDataBytes: modelArtifacts.weightData == null ? + 0 : + modelArtifacts.weightData.byteLength, + }; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var IORouterRegistry = /** @class */ (function () { + function IORouterRegistry() { + this.saveRouters = []; + this.loadRouters = []; + } + IORouterRegistry.getInstance = function () { + if (IORouterRegistry.instance == null) { + IORouterRegistry.instance = new IORouterRegistry(); + } + return IORouterRegistry.instance; + }; + /** + * Register a save-handler router. + * + * @param saveRouter A function that maps a URL-like string onto an instance + * of `IOHandler` with the `save` method defined or `null`. + */ + IORouterRegistry.registerSaveRouter = function (saveRouter) { + IORouterRegistry.getInstance().saveRouters.push(saveRouter); + }; + /** + * Register a load-handler router. + * + * @param loadRouter A function that maps a URL-like string onto an instance + * of `IOHandler` with the `load` method defined or `null`. + */ + IORouterRegistry.registerLoadRouter = function (loadRouter) { + IORouterRegistry.getInstance().loadRouters.push(loadRouter); + }; + /** + * Look up IOHandler for saving, given a URL-like string. + * + * @param url + * @returns If only one match is found, an instance of IOHandler with the + * `save` method defined. If no match is found, `null`. + * @throws Error, if more than one match is found. + */ + IORouterRegistry.getSaveHandlers = function (url) { + return IORouterRegistry.getHandlers(url, 'save'); + }; + /** + * Look up IOHandler for loading, given a URL-like string. + * + * @param url + * @param onProgress Optional, progress callback function, fired periodically + * before the load is completed. + * @returns All valid handlers for `url`, given the currently registered + * handler routers. + */ + IORouterRegistry.getLoadHandlers = function (url, onProgress) { + return IORouterRegistry.getHandlers(url, 'load', onProgress); + }; + IORouterRegistry.getHandlers = function (url, handlerType, onProgress) { + var validHandlers = []; + var routers = handlerType === 'load' ? + IORouterRegistry.getInstance().loadRouters : + IORouterRegistry.getInstance().saveRouters; + routers.forEach(function (router) { + var handler = router(url, onProgress); + if (handler !== null) { + validHandlers.push(handler); + } + }); + return validHandlers; + }; + return IORouterRegistry; + }()); + var registerSaveRouter = function (loudRouter) { + return IORouterRegistry.registerSaveRouter(loudRouter); + }; + var registerLoadRouter = function (loudRouter) { + return IORouterRegistry.registerLoadRouter(loudRouter); + }; + var getSaveHandlers = function (url) { + return IORouterRegistry.getSaveHandlers(url); + }; + var getLoadHandlers = function (url, onProgress) { + return IORouterRegistry.getLoadHandlers(url, onProgress); + }; + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var URL_SCHEME_SUFFIX = '://'; + var ModelStoreManagerRegistry = /** @class */ (function () { + function ModelStoreManagerRegistry() { + this.managers = {}; + } + ModelStoreManagerRegistry.getInstance = function () { + if (ModelStoreManagerRegistry.instance == null) { + ModelStoreManagerRegistry.instance = new ModelStoreManagerRegistry(); + } + return ModelStoreManagerRegistry.instance; + }; + /** + * Register a save-handler router. + * + * @param saveRouter A function that maps a URL-like string onto an instance + * of `IOHandler` with the `save` method defined or `null`. + */ + ModelStoreManagerRegistry.registerManager = function (scheme, manager) { + assert(scheme != null, function () { return 'scheme must not be undefined or null.'; }); + if (scheme.endsWith(URL_SCHEME_SUFFIX)) { + scheme = scheme.slice(0, scheme.indexOf(URL_SCHEME_SUFFIX)); + } + assert(scheme.length > 0, function () { return 'scheme must not be an empty string.'; }); + var registry = ModelStoreManagerRegistry.getInstance(); + assert(registry.managers[scheme] == null, function () { return "A model store manager is already registered for scheme '" + scheme + "'."; }); + registry.managers[scheme] = manager; + }; + ModelStoreManagerRegistry.getManager = function (scheme) { + var manager = this.getInstance().managers[scheme]; + if (manager == null) { + throw new Error("Cannot find model manager for scheme '" + scheme + "'"); + } + return manager; + }; + ModelStoreManagerRegistry.getSchemes = function () { + return Object.keys(this.getInstance().managers); + }; + return ModelStoreManagerRegistry; + }()); + /** + * Helper method for parsing a URL string into a scheme and a path. + * + * @param url E.g., 'localstorage://my-model' + * @returns A dictionary with two fields: scheme and path. + * Scheme: e.g., 'localstorage' in the example above. + * Path: e.g., 'my-model' in the example above. + */ + function parseURL(url) { + if (url.indexOf(URL_SCHEME_SUFFIX) === -1) { + throw new Error("The url string provided does not contain a scheme. " + + "Supported schemes are: " + + ("" + ModelStoreManagerRegistry.getSchemes().join(','))); + } + return { + scheme: url.split(URL_SCHEME_SUFFIX)[0], + path: url.split(URL_SCHEME_SUFFIX)[1], + }; + } + function cloneModelInternal(sourceURL, destURL, deleteSource) { + if (deleteSource === void 0) { deleteSource = false; } + return __awaiter(this, void 0, void 0, function () { + var loadHandlers, loadHandler, saveHandlers, saveHandler, sourceScheme, sourcePath, sameMedium, modelArtifacts, saveResult; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + assert(sourceURL !== destURL, function () { return "Old path and new path are the same: '" + sourceURL + "'"; }); + loadHandlers = IORouterRegistry.getLoadHandlers(sourceURL); + assert(loadHandlers.length > 0, function () { return "Copying failed because no load handler is found for source URL " + sourceURL + "."; }); + assert(loadHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " + + ("load handlers for source URL " + sourceURL + "."); }); + loadHandler = loadHandlers[0]; + saveHandlers = IORouterRegistry.getSaveHandlers(destURL); + assert(saveHandlers.length > 0, function () { return "Copying failed because no save handler is found for destination " + + ("URL " + destURL + "."); }); + assert(saveHandlers.length < 2, function () { return "Copying failed because more than one (" + loadHandlers.length + ") " + + ("save handlers for destination URL " + destURL + "."); }); + saveHandler = saveHandlers[0]; + sourceScheme = parseURL(sourceURL).scheme; + sourcePath = parseURL(sourceURL).path; + sameMedium = sourceScheme === parseURL(sourceURL).scheme; + return [4 /*yield*/, loadHandler.load()]; + case 1: + modelArtifacts = _a.sent(); + if (!(deleteSource && sameMedium)) return [3 /*break*/, 3]; + return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme) + .removeModel(sourcePath)]; + case 2: + _a.sent(); + _a.label = 3; + case 3: return [4 /*yield*/, saveHandler.save(modelArtifacts)]; + case 4: + saveResult = _a.sent(); + if (!(deleteSource && !sameMedium)) return [3 /*break*/, 6]; + return [4 /*yield*/, ModelStoreManagerRegistry.getManager(sourceScheme) + .removeModel(sourcePath)]; + case 5: + _a.sent(); + _a.label = 6; + case 6: return [2 /*return*/, saveResult.modelArtifactsInfo]; + } + }); + }); + } + /** + * List all models stored in registered storage mediums. + * + * For a web browser environment, the registered mediums are Local Storage and + * IndexedDB. + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Delete the model. + * await tf.io.removeModel('localstorage://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * ``` + * + * @returns A `Promise` of a dictionary mapping URLs of existing models to + * their model artifacts info. URLs include medium-specific schemes, e.g., + * 'indexeddb://my/model/1'. Model artifacts info include type of the + * model's topology, byte sizes of the topology, weights, etc. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function listModels() { + return __awaiter(this, void 0, void 0, function () { + var schemes, out, _i, schemes_1, scheme, schemeOut, path, url; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + schemes = ModelStoreManagerRegistry.getSchemes(); + out = {}; + _i = 0, schemes_1 = schemes; + _a.label = 1; + case 1: + if (!(_i < schemes_1.length)) return [3 /*break*/, 4]; + scheme = schemes_1[_i]; + return [4 /*yield*/, ModelStoreManagerRegistry.getManager(scheme).listModels()]; + case 2: + schemeOut = _a.sent(); + for (path in schemeOut) { + url = scheme + URL_SCHEME_SUFFIX + path; + out[url] = schemeOut[path]; + } + _a.label = 3; + case 3: + _i++; + return [3 /*break*/, 1]; + case 4: return [2 /*return*/, out]; + } + }); + }); + } + /** + * Remove a model specified by URL from a reigstered storage medium. + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Delete the model. + * await tf.io.removeModel('localstorage://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * ``` + * + * @param url A URL to a stored model, with a scheme prefix, e.g., + * 'localstorage://my-model-1', 'indexeddb://my/model/2'. + * @returns ModelArtifactsInfo of the deleted model (if and only if deletion + * is successful). + * @throws Error if deletion fails, e.g., if no model exists at `path`. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function removeModel(url) { + return __awaiter(this, void 0, void 0, function () { + var schemeAndPath, manager; + return __generator(this, function (_a) { + schemeAndPath = parseURL(url); + manager = ModelStoreManagerRegistry.getManager(schemeAndPath.scheme); + return [2 /*return*/, manager.removeModel(schemeAndPath.path)]; + }); + }); + } + /** + * Copy a model from one URL to another. + * + * This function supports: + * + * 1. Copying within a storage medium, e.g., + * `tf.io.copyModel('localstorage://model-1', 'localstorage://model-2')` + * 2. Copying between two storage mediums, e.g., + * `tf.io.copyModel('localstorage://model-1', 'indexeddb://model-1')` + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Copy the model, from Local Storage to IndexedDB. + * await tf.io.copyModel( + * 'localstorage://demo/management/model1', + * 'indexeddb://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Remove both models. + * await tf.io.removeModel('localstorage://demo/management/model1'); + * await tf.io.removeModel('indexeddb://demo/management/model1'); + * ``` + * + * @param sourceURL Source URL of copying. + * @param destURL Destination URL of copying. + * @returns ModelArtifactsInfo of the copied model (if and only if copying + * is successful). + * @throws Error if copying fails, e.g., if no model exists at `sourceURL`, or + * if `oldPath` and `newPath` are identical. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function copyModel(sourceURL, destURL) { + return __awaiter(this, void 0, void 0, function () { + var deleteSource; + return __generator(this, function (_a) { + deleteSource = false; + return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)]; + }); + }); + } + /** + * Move a model from one URL to another. + * + * This function supports: + * + * 1. Moving within a storage medium, e.g., + * `tf.io.moveModel('localstorage://model-1', 'localstorage://model-2')` + * 2. Moving between two storage mediums, e.g., + * `tf.io.moveModel('localstorage://model-1', 'indexeddb://model-1')` + * + * ```js + * // First create and save a model. + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * await model.save('localstorage://demo/management/model1'); + * + * // Then list existing models. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Move the model, from Local Storage to IndexedDB. + * await tf.io.moveModel( + * 'localstorage://demo/management/model1', + * 'indexeddb://demo/management/model1'); + * + * // List models again. + * console.log(JSON.stringify(await tf.io.listModels())); + * + * // Remove the moved model. + * await tf.io.removeModel('indexeddb://demo/management/model1'); + * ``` + * + * @param sourceURL Source URL of moving. + * @param destURL Destination URL of moving. + * @returns ModelArtifactsInfo of the copied model (if and only if copying + * is successful). + * @throws Error if moving fails, e.g., if no model exists at `sourceURL`, or + * if `oldPath` and `newPath` are identical. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Management', + * namespace: 'io', + * ignoreCI: true + * } + */ + function moveModel(sourceURL, destURL) { + return __awaiter(this, void 0, void 0, function () { + var deleteSource; + return __generator(this, function (_a) { + deleteSource = true; + return [2 /*return*/, cloneModelInternal(sourceURL, destURL, deleteSource)]; + }); + }); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DATABASE_NAME = 'tensorflowjs'; + var DATABASE_VERSION = 1; + // Model data and ModelArtifactsInfo (metadata) are stored in two separate + // stores for efficient access of the list of stored models and their metadata. + // 1. The object store for model data: topology, weights and weight manifests. + var MODEL_STORE_NAME = 'models_store'; + // 2. The object store for ModelArtifactsInfo, including meta-information such + // as the type of topology (JSON vs binary), byte size of the topology, byte + // size of the weights, etc. + var INFO_STORE_NAME = 'model_info_store'; + function getIndexedDBFactory() { + if (!env().getBool('IS_BROWSER')) { + // TODO(cais): Add more info about what IOHandler subtypes are available. + // Maybe point to a doc page on the web and/or automatically determine + // the available IOHandlers and print them in the error message. + throw new Error('Failed to obtain IndexedDB factory because the current environment' + + 'is not a web browser.'); + } + // tslint:disable-next-line:no-any + var theWindow = window; + var factory = theWindow.indexedDB || theWindow.mozIndexedDB || + theWindow.webkitIndexedDB || theWindow.msIndexedDB || + theWindow.shimIndexedDB; + if (factory == null) { + throw new Error('The current browser does not appear to support IndexedDB.'); + } + return factory; + } + function setUpDatabase(openRequest) { + var db = openRequest.result; + db.createObjectStore(MODEL_STORE_NAME, { keyPath: 'modelPath' }); + db.createObjectStore(INFO_STORE_NAME, { keyPath: 'modelPath' }); + } + /** + * IOHandler subclass: Browser IndexedDB. + * + * See the doc string of `browserIndexedDB` for more details. + */ + var BrowserIndexedDB = /** @class */ (function () { + function BrowserIndexedDB(modelPath) { + this.indexedDB = getIndexedDBFactory(); + if (modelPath == null || !modelPath) { + throw new Error('For IndexedDB, modelPath must not be null, undefined or empty.'); + } + this.modelPath = modelPath; + } + BrowserIndexedDB.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + // TODO(cais): Support saving GraphDef models. + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + + 'in binary formats yet.'); + } + return [2 /*return*/, this.databaseAction(this.modelPath, modelArtifacts)]; + }); + }); + }; + BrowserIndexedDB.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.databaseAction(this.modelPath)]; + }); + }); + }; + /** + * Perform database action to put model artifacts into or read model artifacts + * from IndexedDB object store. + * + * Whether the action is put or get depends on whether `modelArtifacts` is + * specified. If it is specified, the action will be put; otherwise the action + * will be get. + * + * @param modelPath A unique string path for the model. + * @param modelArtifacts If specified, it will be the model artifacts to be + * stored in IndexedDB. + * @returns A `Promise` of `SaveResult`, if the action is put, or a `Promise` + * of `ModelArtifacts`, if the action is get. + */ + BrowserIndexedDB.prototype.databaseAction = function (modelPath, modelArtifacts) { + var _this = this; + return new Promise(function (resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; + openRequest.onsuccess = function () { + var db = openRequest.result; + if (modelArtifacts == null) { + // Read model out from object store. + var modelTx = db.transaction(MODEL_STORE_NAME, 'readonly'); + var modelStore = modelTx.objectStore(MODEL_STORE_NAME); + var getRequest_1 = modelStore.get(_this.modelPath); + getRequest_1.onsuccess = function () { + if (getRequest_1.result == null) { + db.close(); + return reject(new Error("Cannot find model with path '" + _this.modelPath + "' " + + "in IndexedDB.")); + } + else { + resolve(getRequest_1.result.modelArtifacts); + } + }; + getRequest_1.onerror = function (error) { + db.close(); + return reject(getRequest_1.error); + }; + modelTx.oncomplete = function () { return db.close(); }; + } + else { + // Put model into object store. + var modelArtifactsInfo_1 = getModelArtifactsInfoForJSON(modelArtifacts); + // First, put ModelArtifactsInfo into info store. + var infoTx_1 = db.transaction(INFO_STORE_NAME, 'readwrite'); + var infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); + var putInfoRequest_1 = infoStore_1.put({ modelPath: _this.modelPath, modelArtifactsInfo: modelArtifactsInfo_1 }); + var modelTx_1; + putInfoRequest_1.onsuccess = function () { + // Second, put model data into model store. + modelTx_1 = db.transaction(MODEL_STORE_NAME, 'readwrite'); + var modelStore = modelTx_1.objectStore(MODEL_STORE_NAME); + var putModelRequest = modelStore.put({ + modelPath: _this.modelPath, + modelArtifacts: modelArtifacts, + modelArtifactsInfo: modelArtifactsInfo_1 + }); + putModelRequest.onsuccess = function () { return resolve({ modelArtifactsInfo: modelArtifactsInfo_1 }); }; + putModelRequest.onerror = function (error) { + // If the put-model request fails, roll back the info entry as + // well. + infoStore_1 = infoTx_1.objectStore(INFO_STORE_NAME); + var deleteInfoRequest = infoStore_1.delete(_this.modelPath); + deleteInfoRequest.onsuccess = function () { + db.close(); + return reject(putModelRequest.error); + }; + deleteInfoRequest.onerror = function (error) { + db.close(); + return reject(putModelRequest.error); + }; + }; + }; + putInfoRequest_1.onerror = function (error) { + db.close(); + return reject(putInfoRequest_1.error); + }; + infoTx_1.oncomplete = function () { + if (modelTx_1 == null) { + db.close(); + } + else { + modelTx_1.oncomplete = function () { return db.close(); }; + } + }; + } + }; + openRequest.onerror = function (error) { return reject(openRequest.error); }; + }); + }; + BrowserIndexedDB.URL_SCHEME = 'indexeddb://'; + return BrowserIndexedDB; + }()); + var indexedDBRouter = function (url) { + if (!env().getBool('IS_BROWSER')) { + return null; + } + else { + if (!Array.isArray(url) && url.startsWith(BrowserIndexedDB.URL_SCHEME)) { + return browserIndexedDB(url.slice(BrowserIndexedDB.URL_SCHEME.length)); + } + else { + return null; + } + } + }; + IORouterRegistry.registerSaveRouter(indexedDBRouter); + IORouterRegistry.registerLoadRouter(indexedDBRouter); + /** + * Creates a browser IndexedDB IOHandler for saving and loading models. + * + * ```js + * const model = tf.sequential(); + * model.add( + * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); + * + * const saveResult = await model.save('indexeddb://MyModel')); + * console.log(saveResult); + * ``` + * + * @param modelPath A unique identifier for the model to be saved. Must be a + * non-empty string. + * @returns An instance of `BrowserIndexedDB` (sublcass of `IOHandler`), + * which can be used with, e.g., `tf.Model.save`. + */ + function browserIndexedDB(modelPath) { + return new BrowserIndexedDB(modelPath); + } + function maybeStripScheme(key) { + return key.startsWith(BrowserIndexedDB.URL_SCHEME) ? + key.slice(BrowserIndexedDB.URL_SCHEME.length) : + key; + } + var BrowserIndexedDBManager = /** @class */ (function () { + function BrowserIndexedDBManager() { + this.indexedDB = getIndexedDBFactory(); + } + BrowserIndexedDBManager.prototype.listModels = function () { + return __awaiter(this, void 0, void 0, function () { + var _this = this; + return __generator(this, function (_a) { + return [2 /*return*/, new Promise(function (resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; + openRequest.onsuccess = function () { + var db = openRequest.result; + var tx = db.transaction(INFO_STORE_NAME, 'readonly'); + var store = tx.objectStore(INFO_STORE_NAME); + // tslint:disable:max-line-length + // Need to cast `store` as `any` here because TypeScript's DOM + // library does not have the `getAll()` method even though the + // method is supported in the latest version of most mainstream + // browsers: + // https://developer.mozilla.org/en-US/docs/Web/API/IDBObjectStore/getAll + // tslint:enable:max-line-length + // tslint:disable-next-line:no-any + var getAllInfoRequest = store.getAll(); + getAllInfoRequest.onsuccess = function () { + var out = {}; + for (var _i = 0, _a = getAllInfoRequest.result; _i < _a.length; _i++) { + var item = _a[_i]; + out[item.modelPath] = item.modelArtifactsInfo; + } + resolve(out); + }; + getAllInfoRequest.onerror = function (error) { + db.close(); + return reject(getAllInfoRequest.error); + }; + tx.oncomplete = function () { return db.close(); }; + }; + openRequest.onerror = function (error) { return reject(openRequest.error); }; + })]; + }); + }); + }; + BrowserIndexedDBManager.prototype.removeModel = function (path) { + return __awaiter(this, void 0, void 0, function () { + var _this = this; + return __generator(this, function (_a) { + path = maybeStripScheme(path); + return [2 /*return*/, new Promise(function (resolve, reject) { + var openRequest = _this.indexedDB.open(DATABASE_NAME, DATABASE_VERSION); + openRequest.onupgradeneeded = function () { return setUpDatabase(openRequest); }; + openRequest.onsuccess = function () { + var db = openRequest.result; + var infoTx = db.transaction(INFO_STORE_NAME, 'readwrite'); + var infoStore = infoTx.objectStore(INFO_STORE_NAME); + var getInfoRequest = infoStore.get(path); + var modelTx; + getInfoRequest.onsuccess = function () { + if (getInfoRequest.result == null) { + db.close(); + return reject(new Error("Cannot find model with path '" + path + "' " + + "in IndexedDB.")); + } + else { + // First, delete the entry in the info store. + var deleteInfoRequest = infoStore.delete(path); + var deleteModelData_1 = function () { + // Second, delete the entry in the model store. + modelTx = db.transaction(MODEL_STORE_NAME, 'readwrite'); + var modelStore = modelTx.objectStore(MODEL_STORE_NAME); + var deleteModelRequest = modelStore.delete(path); + deleteModelRequest.onsuccess = function () { + return resolve(getInfoRequest.result.modelArtifactsInfo); + }; + deleteModelRequest.onerror = function (error) { + return reject(getInfoRequest.error); + }; + }; + // Proceed with deleting model data regardless of whether deletion + // of info data succeeds or not. + deleteInfoRequest.onsuccess = deleteModelData_1; + deleteInfoRequest.onerror = function (error) { + deleteModelData_1(); + db.close(); + return reject(getInfoRequest.error); + }; + } + }; + getInfoRequest.onerror = function (error) { + db.close(); + return reject(getInfoRequest.error); + }; + infoTx.oncomplete = function () { + if (modelTx == null) { + db.close(); + } + else { + modelTx.oncomplete = function () { return db.close(); }; + } + }; + }; + openRequest.onerror = function (error) { return reject(openRequest.error); }; + })]; + }); + }); + }; + return BrowserIndexedDBManager; + }()); + if (env().getBool('IS_BROWSER')) { + // Wrap the construction and registration, to guard against browsers that + // don't support Local Storage. + try { + ModelStoreManagerRegistry.registerManager(BrowserIndexedDB.URL_SCHEME, new BrowserIndexedDBManager()); + } + catch (err) { + } + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PATH_SEPARATOR = '/'; + var PATH_PREFIX = 'tensorflowjs_models'; + var INFO_SUFFIX = 'info'; + var MODEL_TOPOLOGY_SUFFIX = 'model_topology'; + var WEIGHT_SPECS_SUFFIX = 'weight_specs'; + var WEIGHT_DATA_SUFFIX = 'weight_data'; + var MODEL_METADATA_SUFFIX = 'model_metadata'; + function getModelKeys(path) { + return { + info: [PATH_PREFIX, path, INFO_SUFFIX].join(PATH_SEPARATOR), + topology: [PATH_PREFIX, path, MODEL_TOPOLOGY_SUFFIX].join(PATH_SEPARATOR), + weightSpecs: [PATH_PREFIX, path, WEIGHT_SPECS_SUFFIX].join(PATH_SEPARATOR), + weightData: [PATH_PREFIX, path, WEIGHT_DATA_SUFFIX].join(PATH_SEPARATOR), + modelMetadata: [PATH_PREFIX, path, MODEL_METADATA_SUFFIX].join(PATH_SEPARATOR) + }; + } + /** + * Get model path from a local-storage key. + * + * E.g., 'tensorflowjs_models/my/model/1/info' --> 'my/model/1' + * + * @param key + */ + function getModelPathFromKey(key) { + var items = key.split(PATH_SEPARATOR); + if (items.length < 3) { + throw new Error("Invalid key format: " + key); + } + return items.slice(1, items.length - 1).join(PATH_SEPARATOR); + } + function maybeStripScheme$1(key) { + return key.startsWith(BrowserLocalStorage.URL_SCHEME) ? + key.slice(BrowserLocalStorage.URL_SCHEME.length) : + key; + } + /** + * IOHandler subclass: Browser Local Storage. + * + * See the doc string to `browserLocalStorage` for more details. + */ + var BrowserLocalStorage = /** @class */ (function () { + function BrowserLocalStorage(modelPath) { + if (!env().getBool('IS_BROWSER') || + typeof window.localStorage === 'undefined') { + // TODO(cais): Add more info about what IOHandler subtypes are + // available. + // Maybe point to a doc page on the web and/or automatically determine + // the available IOHandlers and print them in the error message. + throw new Error('The current environment does not support local storage.'); + } + this.LS = window.localStorage; + if (modelPath == null || !modelPath) { + throw new Error('For local storage, modelPath must not be null, undefined or empty.'); + } + this.modelPath = modelPath; + this.keys = getModelKeys(this.modelPath); + } + /** + * Save model artifacts to browser local storage. + * + * See the documentation to `browserLocalStorage` for details on the saved + * artifacts. + * + * @param modelArtifacts The model artifacts to be stored. + * @returns An instance of SaveResult. + */ + BrowserLocalStorage.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + var topology, weightSpecs, modelArtifactsInfo; + return __generator(this, function (_a) { + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('BrowserLocalStorage.save() does not support saving model topology ' + + 'in binary formats yet.'); + } + else { + topology = JSON.stringify(modelArtifacts.modelTopology); + weightSpecs = JSON.stringify(modelArtifacts.weightSpecs); + modelArtifactsInfo = getModelArtifactsInfoForJSON(modelArtifacts); + try { + this.LS.setItem(this.keys.info, JSON.stringify(modelArtifactsInfo)); + this.LS.setItem(this.keys.topology, topology); + this.LS.setItem(this.keys.weightSpecs, weightSpecs); + this.LS.setItem(this.keys.weightData, arrayBufferToBase64String(modelArtifacts.weightData)); + this.LS.setItem(this.keys.modelMetadata, JSON.stringify({ + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + userDefinedMetadata: modelArtifacts.userDefinedMetadata + })); + return [2 /*return*/, { modelArtifactsInfo: modelArtifactsInfo }]; + } + catch (err) { + // If saving failed, clean up all items saved so far. + this.LS.removeItem(this.keys.info); + this.LS.removeItem(this.keys.topology); + this.LS.removeItem(this.keys.weightSpecs); + this.LS.removeItem(this.keys.weightData); + this.LS.removeItem(this.keys.modelMetadata); + throw new Error("Failed to save model '" + this.modelPath + "' to local storage: " + + "size quota being exceeded is a possible cause of this failure: " + + ("modelTopologyBytes=" + modelArtifactsInfo.modelTopologyBytes + ", ") + + ("weightSpecsBytes=" + modelArtifactsInfo.weightSpecsBytes + ", ") + + ("weightDataBytes=" + modelArtifactsInfo.weightDataBytes + ".")); + } + } + return [2 /*return*/]; + }); + }); + }; + /** + * Load a model from local storage. + * + * See the documentation to `browserLocalStorage` for details on the saved + * artifacts. + * + * @returns The loaded model (if loading succeeds). + */ + BrowserLocalStorage.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + var info, out, topology, weightSpecs, metadataString, metadata, weightDataBase64; + return __generator(this, function (_a) { + info = JSON.parse(this.LS.getItem(this.keys.info)); + if (info == null) { + throw new Error("In local storage, there is no model with name '" + this.modelPath + "'"); + } + if (info.modelTopologyType !== 'JSON') { + throw new Error('BrowserLocalStorage does not support loading non-JSON model ' + + 'topology yet.'); + } + out = {}; + topology = JSON.parse(this.LS.getItem(this.keys.topology)); + if (topology == null) { + throw new Error("In local storage, the topology of model '" + this.modelPath + "' " + + "is missing."); + } + out.modelTopology = topology; + weightSpecs = JSON.parse(this.LS.getItem(this.keys.weightSpecs)); + if (weightSpecs == null) { + throw new Error("In local storage, the weight specs of model '" + this.modelPath + "' " + + "are missing."); + } + out.weightSpecs = weightSpecs; + metadataString = this.LS.getItem(this.keys.modelMetadata); + if (metadataString != null) { + metadata = JSON.parse(metadataString); + out.format = metadata['format']; + out.generatedBy = metadata['generatedBy']; + out.convertedBy = metadata['convertedBy']; + out.userDefinedMetadata = metadata['userDefinedMetadata']; + } + weightDataBase64 = this.LS.getItem(this.keys.weightData); + if (weightDataBase64 == null) { + throw new Error("In local storage, the binary weight values of model " + + ("'" + this.modelPath + "' are missing.")); + } + out.weightData = base64StringToArrayBuffer(weightDataBase64); + return [2 /*return*/, out]; + }); + }); + }; + BrowserLocalStorage.URL_SCHEME = 'localstorage://'; + return BrowserLocalStorage; + }()); + var localStorageRouter = function (url) { + if (!env().getBool('IS_BROWSER')) { + return null; + } + else { + if (!Array.isArray(url) && url.startsWith(BrowserLocalStorage.URL_SCHEME)) { + return browserLocalStorage(url.slice(BrowserLocalStorage.URL_SCHEME.length)); + } + else { + return null; + } + } + }; + IORouterRegistry.registerSaveRouter(localStorageRouter); + IORouterRegistry.registerLoadRouter(localStorageRouter); + /** + * Factory function for local storage IOHandler. + * + * This `IOHandler` supports both `save` and `load`. + * + * For each model's saved artifacts, four items are saved to local storage. + * - `${PATH_SEPARATOR}/${modelPath}/info`: Contains meta-info about the + * model, such as date saved, type of the topology, size in bytes, etc. + * - `${PATH_SEPARATOR}/${modelPath}/topology`: Model topology. For Keras- + * style models, this is a stringized JSON. + * - `${PATH_SEPARATOR}/${modelPath}/weight_specs`: Weight specs of the + * model, can be used to decode the saved binary weight values (see + * item below). + * - `${PATH_SEPARATOR}/${modelPath}/weight_data`: Concatenated binary + * weight values, stored as a base64-encoded string. + * + * Saving may throw an `Error` if the total size of the artifacts exceed the + * browser-specific quota. + * + * @param modelPath A unique identifier for the model to be saved. Must be a + * non-empty string. + * @returns An instance of `IOHandler`, which can be used with, e.g., + * `tf.Model.save`. + */ + function browserLocalStorage(modelPath) { + return new BrowserLocalStorage(modelPath); + } + var BrowserLocalStorageManager = /** @class */ (function () { + function BrowserLocalStorageManager() { + assert(env().getBool('IS_BROWSER'), function () { return 'Current environment is not a web browser'; }); + assert(typeof window.localStorage !== 'undefined', function () { return 'Current browser does not appear to support localStorage'; }); + this.LS = window.localStorage; + } + BrowserLocalStorageManager.prototype.listModels = function () { + return __awaiter(this, void 0, void 0, function () { + var out, prefix, suffix, i, key, modelPath; + return __generator(this, function (_a) { + out = {}; + prefix = PATH_PREFIX + PATH_SEPARATOR; + suffix = PATH_SEPARATOR + INFO_SUFFIX; + for (i = 0; i < this.LS.length; ++i) { + key = this.LS.key(i); + if (key.startsWith(prefix) && key.endsWith(suffix)) { + modelPath = getModelPathFromKey(key); + out[modelPath] = JSON.parse(this.LS.getItem(key)); + } + } + return [2 /*return*/, out]; + }); + }); + }; + BrowserLocalStorageManager.prototype.removeModel = function (path) { + return __awaiter(this, void 0, void 0, function () { + var keys, info; + return __generator(this, function (_a) { + path = maybeStripScheme$1(path); + keys = getModelKeys(path); + if (this.LS.getItem(keys.info) == null) { + throw new Error("Cannot find model at path '" + path + "'"); + } + info = JSON.parse(this.LS.getItem(keys.info)); + this.LS.removeItem(keys.info); + this.LS.removeItem(keys.topology); + this.LS.removeItem(keys.weightSpecs); + this.LS.removeItem(keys.weightData); + return [2 /*return*/, info]; + }); + }); + }; + return BrowserLocalStorageManager; + }()); + if (env().getBool('IS_BROWSER')) { + // Wrap the construction and registration, to guard against browsers that + // don't support Local Storage. + try { + ModelStoreManagerRegistry.registerManager(BrowserLocalStorage.URL_SCHEME, new BrowserLocalStorageManager()); + } + catch (err) { + } + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var DEFAULT_FILE_NAME_PREFIX = 'model'; + var DEFAULT_JSON_EXTENSION_NAME = '.json'; + var DEFAULT_WEIGHT_DATA_EXTENSION_NAME = '.weights.bin'; + function defer(f) { + return new Promise(function (resolve) { return setTimeout(resolve); }).then(f); + } + var BrowserDownloads = /** @class */ (function () { + function BrowserDownloads(fileNamePrefix) { + if (!env().getBool('IS_BROWSER')) { + // TODO(cais): Provide info on what IOHandlers are available under the + // current environment. + throw new Error('browserDownloads() cannot proceed because the current environment ' + + 'is not a browser.'); + } + if (fileNamePrefix.startsWith(BrowserDownloads.URL_SCHEME)) { + fileNamePrefix = fileNamePrefix.slice(BrowserDownloads.URL_SCHEME.length); + } + if (fileNamePrefix == null || fileNamePrefix.length === 0) { + fileNamePrefix = DEFAULT_FILE_NAME_PREFIX; + } + this.modelTopologyFileName = fileNamePrefix + DEFAULT_JSON_EXTENSION_NAME; + this.weightDataFileName = + fileNamePrefix + DEFAULT_WEIGHT_DATA_EXTENSION_NAME; + } + BrowserDownloads.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + var weightsURL, weightsManifest, modelTopologyAndWeightManifest, modelTopologyAndWeightManifestURL, jsonAnchor_1, weightDataAnchor_1; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (typeof (document) === 'undefined') { + throw new Error('Browser downloads are not supported in ' + + 'this environment since `document` is not present'); + } + weightsURL = window.URL.createObjectURL(new Blob([modelArtifacts.weightData], { type: 'application/octet-stream' })); + if (!(modelArtifacts.modelTopology instanceof ArrayBuffer)) return [3 /*break*/, 1]; + throw new Error('BrowserDownloads.save() does not support saving model topology ' + + 'in binary formats yet.'); + case 1: + weightsManifest = [{ + paths: ['./' + this.weightDataFileName], + weights: modelArtifacts.weightSpecs + }]; + modelTopologyAndWeightManifest = { + modelTopology: modelArtifacts.modelTopology, + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + weightsManifest: weightsManifest + }; + modelTopologyAndWeightManifestURL = window.URL.createObjectURL(new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: 'application/json' })); + jsonAnchor_1 = this.jsonAnchor == null ? document.createElement('a') : + this.jsonAnchor; + jsonAnchor_1.download = this.modelTopologyFileName; + jsonAnchor_1.href = modelTopologyAndWeightManifestURL; + // Trigger downloads by evoking a click event on the download anchors. + // When multiple downloads are started synchronously, Firefox will only + // save the last one. + return [4 /*yield*/, defer(function () { return jsonAnchor_1.dispatchEvent(new MouseEvent('click')); })]; + case 2: + // Trigger downloads by evoking a click event on the download anchors. + // When multiple downloads are started synchronously, Firefox will only + // save the last one. + _a.sent(); + if (!(modelArtifacts.weightData != null)) return [3 /*break*/, 4]; + weightDataAnchor_1 = this.weightDataAnchor == null ? + document.createElement('a') : + this.weightDataAnchor; + weightDataAnchor_1.download = this.weightDataFileName; + weightDataAnchor_1.href = weightsURL; + return [4 /*yield*/, defer(function () { return weightDataAnchor_1.dispatchEvent(new MouseEvent('click')); })]; + case 3: + _a.sent(); + _a.label = 4; + case 4: return [2 /*return*/, { modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts) }]; + } + }); + }); + }; + BrowserDownloads.URL_SCHEME = 'downloads://'; + return BrowserDownloads; + }()); + var BrowserFiles = /** @class */ (function () { + function BrowserFiles(files) { + if (files == null || files.length < 1) { + throw new Error("When calling browserFiles, at least 1 file is required, " + + ("but received " + files)); + } + this.files = files; + } + BrowserFiles.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + var jsonFile, weightFiles; + var _this = this; + return __generator(this, function (_a) { + jsonFile = this.files[0]; + weightFiles = this.files.slice(1); + return [2 /*return*/, new Promise(function (resolve, reject) { + var jsonReader = new FileReader(); + jsonReader.onload = function (event) { + // tslint:disable-next-line:no-any + var modelJSON = JSON.parse(event.target.result); + var modelTopology = modelJSON.modelTopology; + if (modelTopology == null) { + reject(new Error("modelTopology field is missing from file " + jsonFile.name)); + return; + } + if (weightFiles.length === 0) { + resolve({ modelTopology: modelTopology }); + } + var weightsManifest = modelJSON.weightsManifest; + if (weightsManifest == null) { + reject(new Error("weightManifest field is missing from file " + jsonFile.name)); + return; + } + var pathToFile; + try { + pathToFile = + _this.checkManifestAndWeightFiles(weightsManifest, weightFiles); + } + catch (err) { + reject(err); + return; + } + var weightSpecs = []; + var paths = []; + var perFileBuffers = []; + weightsManifest.forEach(function (weightsGroup) { + weightsGroup.paths.forEach(function (path) { + paths.push(path); + perFileBuffers.push(null); + }); + weightSpecs.push.apply(weightSpecs, weightsGroup.weights); + }); + weightsManifest.forEach(function (weightsGroup) { + weightsGroup.paths.forEach(function (path) { + var weightFileReader = new FileReader(); + weightFileReader.onload = function (event) { + // tslint:disable-next-line:no-any + var weightData = event.target.result; + var index = paths.indexOf(path); + perFileBuffers[index] = weightData; + if (perFileBuffers.indexOf(null) === -1) { + resolve({ + modelTopology: modelTopology, + weightSpecs: weightSpecs, + weightData: concatenateArrayBuffers(perFileBuffers), + format: modelJSON.format, + generatedBy: modelJSON.generatedBy, + convertedBy: modelJSON.convertedBy, + userDefinedMetadata: modelJSON.userDefinedMetadata + }); + } + }; + weightFileReader.onerror = function (error) { + return reject("Failed to weights data from file of path '" + path + "'."); + }; + weightFileReader.readAsArrayBuffer(pathToFile[path]); + }); + }); + }; + jsonReader.onerror = function (error) { return reject("Failed to read model topology and weights manifest JSON " + + ("from file '" + jsonFile.name + "'. BrowserFiles supports loading ") + + "Keras-style tf.Model artifacts only."); }; + jsonReader.readAsText(jsonFile); + })]; + }); + }); + }; + /** + * Check the compatibility between weights manifest and weight files. + */ + BrowserFiles.prototype.checkManifestAndWeightFiles = function (manifest, files) { + var basenames = []; + var fileNames = files.map(function (file) { return basename(file.name); }); + var pathToFile = {}; + for (var _i = 0, manifest_1 = manifest; _i < manifest_1.length; _i++) { + var group = manifest_1[_i]; + group.paths.forEach(function (path) { + var pathBasename = basename(path); + if (basenames.indexOf(pathBasename) !== -1) { + throw new Error("Duplicate file basename found in weights manifest: " + + ("'" + pathBasename + "'")); + } + basenames.push(pathBasename); + if (fileNames.indexOf(pathBasename) === -1) { + throw new Error("Weight file with basename '" + pathBasename + "' is not provided."); + } + else { + pathToFile[path] = files[fileNames.indexOf(pathBasename)]; + } + }); + } + if (basenames.length !== files.length) { + throw new Error("Mismatch in the number of files in weights manifest " + + ("(" + basenames.length + ") and the number of weight files provided ") + + ("(" + files.length + ").")); + } + return pathToFile; + }; + return BrowserFiles; + }()); + var browserDownloadsRouter = function (url) { + if (!env().getBool('IS_BROWSER')) { + return null; + } + else { + if (!Array.isArray(url) && url.startsWith(BrowserDownloads.URL_SCHEME)) { + return browserDownloads(url.slice(BrowserDownloads.URL_SCHEME.length)); + } + else { + return null; + } + } + }; + IORouterRegistry.registerSaveRouter(browserDownloadsRouter); + /** + * Creates an IOHandler that triggers file downloads from the browser. + * + * The returned `IOHandler` instance can be used as model exporting methods such + * as `tf.Model.save` and supports only saving. + * + * ```js + * const model = tf.sequential(); + * model.add(tf.layers.dense( + * {units: 1, inputShape: [10], activation: 'sigmoid'})); + * const saveResult = await model.save('downloads://mymodel'); + * // This will trigger downloading of two files: + * // 'mymodel.json' and 'mymodel.weights.bin'. + * console.log(saveResult); + * ``` + * + * @param fileNamePrefix Prefix name of the files to be downloaded. For use with + * `tf.Model`, `fileNamePrefix` should follow either of the following two + * formats: + * 1. `null` or `undefined`, in which case the default file + * names will be used: + * - 'model.json' for the JSON file containing the model topology and + * weights manifest. + * - 'model.weights.bin' for the binary file containing the binary weight + * values. + * 2. A single string or an Array of a single string, as the file name prefix. + * For example, if `'foo'` is provided, the downloaded JSON + * file and binary weights file will be named 'foo.json' and + * 'foo.weights.bin', respectively. + * @param config Additional configuration for triggering downloads. + * @returns An instance of `BrowserDownloads` `IOHandler`. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Loading', + * namespace: 'io', + * ignoreCI: true + * } + */ + function browserDownloads(fileNamePrefix) { + if (fileNamePrefix === void 0) { fileNamePrefix = 'model'; } + return new BrowserDownloads(fileNamePrefix); + } + /** + * Creates an IOHandler that loads model artifacts from user-selected files. + * + * This method can be used for loading from files such as user-selected files + * in the browser. + * When used in conjunction with `tf.loadLayersModel`, an instance of + * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. + * + * ```js + * // Note: This code snippet won't run properly without the actual file input + * // elements in the HTML DOM. + * + * // Suppose there are two HTML file input (``) + * // elements. + * const uploadJSONInput = document.getElementById('upload-json'); + * const uploadWeightsInput = document.getElementById('upload-weights'); + * const model = await tf.loadLayersModel(tf.io.browserFiles( + * [uploadJSONInput.files[0], uploadWeightsInput.files[0]])); + * ``` + * + * @param files `File`s to load from. Currently, this function supports only + * loading from files that contain Keras-style models (i.e., `tf.Model`s), for + * which an `Array` of `File`s is expected (in that order): + * - A JSON file containing the model topology and weight manifest. + * - Optionally, One or more binary files containing the binary weights. + * These files must have names that match the paths in the `weightsManifest` + * contained by the aforementioned JSON file, or errors will be thrown + * during loading. These weights files have the same format as the ones + * generated by `tensorflowjs_converter` that comes with the `tensorflowjs` + * Python PIP package. If no weights files are provided, only the model + * topology will be loaded from the JSON file above. + * @returns An instance of `Files` `IOHandler`. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Loading', + * namespace: 'io', + * ignoreCI: true + * } + */ + function browserFiles(files) { + return new BrowserFiles(files); + } + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Monitor Promise.all progress, fire onProgress callback function. + * + * @param promises Promise list going to be monitored + * @param onProgress Callback function. Fired when a promise resolved. + * @param startFraction Optional fraction start. Default to 0. + * @param endFraction Optional fraction end. Default to 1. + */ + function monitorPromisesProgress(promises, onProgress, startFraction, endFraction) { + checkPromises(promises); + startFraction = startFraction == null ? 0 : startFraction; + endFraction = endFraction == null ? 1 : endFraction; + checkFraction(startFraction, endFraction); + var resolvedPromise = 0; + var registerMonitor = function (promise) { + promise.then(function (value) { + var fraction = startFraction + + ++resolvedPromise / promises.length * (endFraction - startFraction); + // pass fraction as parameter to callback function. + onProgress(fraction); + return value; + }); + return promise; + }; + function checkPromises(promises) { + assert(promises != null && Array.isArray(promises) && promises.length > 0, function () { return 'promises must be a none empty array'; }); + } + function checkFraction(startFraction, endFraction) { + assert(startFraction >= 0 && startFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + + ("got startFraction " + startFraction); }); + assert(endFraction >= 0 && endFraction <= 1, function () { return "Progress fraction must be in range [0, 1], but " + + ("got endFraction " + endFraction); }); + assert(endFraction >= startFraction, function () { return "startFraction must be no more than endFraction, but " + + ("got startFraction " + startFraction + " and endFraction ") + + ("" + endFraction); }); + } + return Promise.all(promises.map(registerMonitor)); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Reads binary weights data from a number of URLs. + * + * @param fetchURLs URLs to send the HTTP requests at, using `fetch` calls. + * @param requestOptions RequestInit (options) for the HTTP requests. + * @param fetchFunc Optional overriding value for the `window.fetch` function. + * @param onProgress Optional, progress callback function, fired periodically + * before the load is completed. + * @returns A `Promise` of an Array of `ArrayBuffer`. The Array has the same + * length as `fetchURLs`. + */ + function loadWeightsAsArrayBuffer(fetchURLs, loadOptions) { + return __awaiter(this, void 0, void 0, function () { + var fetchFunc, requests, fetchStartFraction, fetchEndFraction, responses, _a, bufferPromises, bufferStartFraction, bufferEndFraction, buffers, _b; + return __generator(this, function (_c) { + switch (_c.label) { + case 0: + if (loadOptions == null) { + loadOptions = {}; + } + fetchFunc = loadOptions.fetchFunc == null ? env().platform.fetch : + loadOptions.fetchFunc; + requests = fetchURLs.map(function (fetchURL) { + return fetchFunc(fetchURL, loadOptions.requestInit, { isBinary: true }); + }); + fetchStartFraction = 0; + fetchEndFraction = 0.5; + if (!(loadOptions.onProgress == null)) return [3 /*break*/, 2]; + return [4 /*yield*/, Promise.all(requests)]; + case 1: + _a = _c.sent(); + return [3 /*break*/, 4]; + case 2: return [4 /*yield*/, monitorPromisesProgress(requests, loadOptions.onProgress, fetchStartFraction, fetchEndFraction)]; + case 3: + _a = _c.sent(); + _c.label = 4; + case 4: + responses = _a; + bufferPromises = responses.map(function (response) { return response.arrayBuffer(); }); + bufferStartFraction = 0.5; + bufferEndFraction = 1; + if (!(loadOptions.onProgress == null)) return [3 /*break*/, 6]; + return [4 /*yield*/, Promise.all(bufferPromises)]; + case 5: + _b = _c.sent(); + return [3 /*break*/, 8]; + case 6: return [4 /*yield*/, monitorPromisesProgress(bufferPromises, loadOptions.onProgress, bufferStartFraction, bufferEndFraction)]; + case 7: + _b = _c.sent(); + _c.label = 8; + case 8: + buffers = _b; + return [2 /*return*/, buffers]; + } + }); + }); + } + /** + * Reads a weights manifest JSON configuration, fetches the weights and + * returns them as `Tensor`s. + * + * @param manifest The weights manifest JSON. + * @param filePathPrefix The path prefix for filenames given in the manifest. + * Defaults to the empty string. + * @param weightNames The names of the weights to be fetched. + */ + function loadWeights(manifest, filePathPrefix, weightNames, requestInit) { + if (filePathPrefix === void 0) { filePathPrefix = ''; } + return __awaiter(this, void 0, void 0, function () { + var fetchWeights, loadWeights; + return __generator(this, function (_a) { + fetchWeights = function (fetchUrls) { + return loadWeightsAsArrayBuffer(fetchUrls, { requestInit: requestInit }); + }; + loadWeights = weightsLoaderFactory(fetchWeights); + return [2 /*return*/, loadWeights(manifest, filePathPrefix, weightNames)]; + }); + }); + } + /** + * Creates a function, which reads a weights manifest JSON configuration, + * fetches the weight files using the specified function and returns them as + * `Tensor`s. + * + * ```js + * // example for creating a nodejs weight loader, which reads the weight files + * // from disk using fs.readFileSync + * + * import * as fs from 'fs' + * + * const fetchWeightsFromDisk = (filePaths: string[]) => + * filePaths.map(filePath => fs.readFileSync(filePath).buffer) + * + * const loadWeights = tf.io.weightsLoaderFactory(fetchWeightsFromDisk) + * + * const manifest = JSON.parse( + * fs.readFileSync('./my_model-weights_manifest').toString() + * ) + * const weightMap = await loadWeights(manifest, './') + * ``` + * @param fetchWeightsFunction The function used for fetching the weight files. + * @returns Weight loading function. + */ + function weightsLoaderFactory(fetchWeightsFunction) { + var _this = this; + return function (manifest, filePathPrefix, weightNames) { + if (filePathPrefix === void 0) { filePathPrefix = ''; } + return __awaiter(_this, void 0, void 0, function () { + var groupIndicesToFetchMap, groupWeightsToFetch, weightsFound, allManifestWeightNames, weightsNotFound, groupIndicesToFetch, fetchUrls, buffers, weightsTensorMap, bufferIndexOffset; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + groupIndicesToFetchMap = manifest.map(function () { return false; }); + groupWeightsToFetch = {}; + weightsFound = weightNames != null ? weightNames.map(function () { return false; }) : []; + allManifestWeightNames = []; + manifest.forEach(function (manifestGroupConfig, groupIndex) { + var groupOffset = 0; + manifestGroupConfig.weights.forEach(function (weightsEntry) { + var rawDtype = ('quantization' in weightsEntry) ? + weightsEntry.quantization.dtype : + weightsEntry.dtype; + var weightsBytes = DTYPE_VALUE_SIZE_MAP[rawDtype] * + sizeFromShape(weightsEntry.shape); + var enqueueWeightsForFetchingFn = function () { + groupIndicesToFetchMap[groupIndex] = true; + if (groupWeightsToFetch[groupIndex] == null) { + groupWeightsToFetch[groupIndex] = []; + } + groupWeightsToFetch[groupIndex].push({ + manifestEntry: weightsEntry, + groupOffset: groupOffset, + sizeBytes: weightsBytes + }); + }; + if (weightNames != null) { + weightNames.forEach(function (weightName, weightIndex) { + if (weightName === weightsEntry.name) { + enqueueWeightsForFetchingFn(); + weightsFound[weightIndex] = true; + } + }); + } + else { + enqueueWeightsForFetchingFn(); + } + allManifestWeightNames.push(weightsEntry.name); + groupOffset += weightsBytes; + }); + }); + if (!weightsFound.every(function (found) { return found; })) { + weightsNotFound = weightNames.filter(function (_, i) { return !weightsFound[i]; }); + throw new Error("Could not find weights in manifest with names: " + + (weightsNotFound.join(', ') + ". \n") + + "Manifest JSON has weights with names: " + + (allManifestWeightNames.join(', ') + ".")); + } + groupIndicesToFetch = groupIndicesToFetchMap.reduce(function (accumulator, shouldFetch, i) { + if (shouldFetch) { + accumulator.push(i); + } + return accumulator; + }, []); + fetchUrls = []; + groupIndicesToFetch.forEach(function (i) { + manifest[i].paths.forEach(function (filepath) { + var fetchUrl = filePathPrefix + + (!filePathPrefix.endsWith('/') ? '/' : '') + filepath; + fetchUrls.push(fetchUrl); + }); + }); + return [4 /*yield*/, fetchWeightsFunction(fetchUrls)]; + case 1: + buffers = _a.sent(); + weightsTensorMap = {}; + bufferIndexOffset = 0; + groupIndicesToFetch.forEach(function (i) { + var numBuffers = manifest[i].paths.length; + var groupBytes = 0; + for (var i_1 = 0; i_1 < numBuffers; i_1++) { + groupBytes += buffers[bufferIndexOffset + i_1].byteLength; + } + // Create a buffer for the whole group. + var groupBuffer = new ArrayBuffer(groupBytes); + var groupByteBuffer = new Uint8Array(groupBuffer); + var groupBufferOffset = 0; + for (var i_2 = 0; i_2 < numBuffers; i_2++) { + var buffer = new Uint8Array(buffers[bufferIndexOffset + i_2]); + groupByteBuffer.set(buffer, groupBufferOffset); + groupBufferOffset += buffer.byteLength; + } + var weightsEntries = groupWeightsToFetch[i]; + weightsEntries.forEach(function (weightsEntry) { + var byteBuffer = groupBuffer.slice(weightsEntry.groupOffset, weightsEntry.groupOffset + weightsEntry.sizeBytes); + var nameToTensorMap = decodeWeights(byteBuffer, [weightsEntry.manifestEntry]); + for (var name_1 in nameToTensorMap) { + weightsTensorMap[name_1] = nameToTensorMap[name_1]; + } + }); + bufferIndexOffset += numBuffers; + }); + return [2 /*return*/, weightsTensorMap]; + } + }); + }); + }; + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var OCTET_STREAM_MIME_TYPE = 'application/octet-stream'; + var JSON_TYPE = 'application/json'; + var HTTPRequest = /** @class */ (function () { + function HTTPRequest(path, loadOptions) { + this.DEFAULT_METHOD = 'POST'; + if (loadOptions == null) { + loadOptions = {}; + } + this.weightPathPrefix = loadOptions.weightPathPrefix; + this.onProgress = loadOptions.onProgress; + if (loadOptions.fetchFunc != null) { + assert(typeof loadOptions.fetchFunc === 'function', function () { return 'Must pass a function that matches the signature of ' + + '`fetch` (see ' + + 'https://developer.mozilla.org/en-US/docs/Web/API/Fetch_API)'; }); + this.fetch = loadOptions.fetchFunc; + } + else { + this.fetch = env().platform.fetch; + } + assert(path != null && path.length > 0, function () { return 'URL path for http must not be null, undefined or ' + + 'empty.'; }); + if (Array.isArray(path)) { + assert(path.length === 2, function () { return 'URL paths for http must have a length of 2, ' + + ("(actual length is " + path.length + ")."); }); + } + this.path = path; + if (loadOptions.requestInit != null && + loadOptions.requestInit.body != null) { + throw new Error('requestInit is expected to have no pre-existing body, but has one.'); + } + this.requestInit = loadOptions.requestInit || {}; + } + HTTPRequest.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + var init, weightsManifest, modelTopologyAndWeightManifest, response; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: + if (modelArtifacts.modelTopology instanceof ArrayBuffer) { + throw new Error('BrowserHTTPRequest.save() does not support saving model topology ' + + 'in binary formats yet.'); + } + init = Object.assign({ method: this.DEFAULT_METHOD }, this.requestInit); + init.body = new FormData(); + weightsManifest = [{ + paths: ['./model.weights.bin'], + weights: modelArtifacts.weightSpecs, + }]; + modelTopologyAndWeightManifest = { + modelTopology: modelArtifacts.modelTopology, + format: modelArtifacts.format, + generatedBy: modelArtifacts.generatedBy, + convertedBy: modelArtifacts.convertedBy, + userDefinedMetadata: modelArtifacts.userDefinedMetadata, + weightsManifest: weightsManifest + }; + init.body.append('model.json', new Blob([JSON.stringify(modelTopologyAndWeightManifest)], { type: JSON_TYPE }), 'model.json'); + if (modelArtifacts.weightData != null) { + init.body.append('model.weights.bin', new Blob([modelArtifacts.weightData], { type: OCTET_STREAM_MIME_TYPE }), 'model.weights.bin'); + } + return [4 /*yield*/, this.fetch(this.path, init)]; + case 1: + response = _a.sent(); + if (response.ok) { + return [2 /*return*/, { + modelArtifactsInfo: getModelArtifactsInfoForJSON(modelArtifacts), + responses: [response], + }]; + } + else { + throw new Error("BrowserHTTPRequest.save() failed due to HTTP response status " + + (response.status + ".")); + } + return [2 /*return*/]; + } + }); + }); + }; + /** + * Load model artifacts via HTTP request(s). + * + * See the documentation to `tf.io.http` for details on the saved + * artifacts. + * + * @returns The loaded model artifacts (if loading succeeds). + */ + HTTPRequest.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + var modelConfigRequest, modelConfig, e_1, message, modelTopology, weightsManifest, generatedBy, convertedBy, format, userDefinedMetadata, weightSpecs, weightData, results; + return __generator(this, function (_a) { + switch (_a.label) { + case 0: return [4 /*yield*/, this.fetch(this.path, this.requestInit)]; + case 1: + modelConfigRequest = _a.sent(); + if (!modelConfigRequest.ok) { + throw new Error("Request to " + this.path + " failed with status code " + + (modelConfigRequest.status + ". Please verify this URL points to ") + + "the model JSON of the model to load."); + } + _a.label = 2; + case 2: + _a.trys.push([2, 4, , 5]); + return [4 /*yield*/, modelConfigRequest.json()]; + case 3: + modelConfig = _a.sent(); + return [3 /*break*/, 5]; + case 4: + e_1 = _a.sent(); + message = "Failed to parse model JSON of response from " + this.path + "."; + // TODO(nsthorat): Remove this after some time when we're comfortable that + // .pb files are mostly gone. + if (this.path.endsWith('.pb')) { + message += ' Your path contains a .pb file extension. ' + + 'Support for .pb models have been removed in TensorFlow.js 1.0 ' + + 'in favor of .json models. You can re-convert your Python ' + + 'TensorFlow model using the TensorFlow.js 1.0 conversion scripts ' + + 'or you can convert your.pb models with the \'pb2json\'' + + 'NPM script in the tensorflow/tfjs-converter repository.'; + } + else { + message += ' Please make sure the server is serving valid ' + + 'JSON for this request.'; + } + throw new Error(message); + case 5: + modelTopology = modelConfig.modelTopology; + weightsManifest = modelConfig.weightsManifest; + generatedBy = modelConfig.generatedBy; + convertedBy = modelConfig.convertedBy; + format = modelConfig.format; + userDefinedMetadata = modelConfig.userDefinedMetadata; + // We do not allow both modelTopology and weightsManifest to be missing. + if (modelTopology == null && weightsManifest == null) { + throw new Error("The JSON from HTTP path " + this.path + " contains neither model " + + "topology or manifest for weights."); + } + if (!(weightsManifest != null)) return [3 /*break*/, 7]; + return [4 /*yield*/, this.loadWeights(weightsManifest)]; + case 6: + results = _a.sent(); + weightSpecs = results[0], weightData = results[1]; + _a.label = 7; + case 7: return [2 /*return*/, { + modelTopology: modelTopology, + weightSpecs: weightSpecs, + weightData: weightData, + userDefinedMetadata: userDefinedMetadata, + generatedBy: generatedBy, + convertedBy: convertedBy, + format: format + }]; + } + }); + }); + }; + HTTPRequest.prototype.loadWeights = function (weightsManifest) { + return __awaiter(this, void 0, void 0, function () { + var weightPath, _a, prefix, suffix, pathPrefix, weightSpecs, _i, weightsManifest_1, entry, fetchURLs, buffers; + return __generator(this, function (_b) { + switch (_b.label) { + case 0: + weightPath = Array.isArray(this.path) ? this.path[1] : this.path; + _a = parseUrl(weightPath), prefix = _a[0], suffix = _a[1]; + pathPrefix = this.weightPathPrefix || prefix; + weightSpecs = []; + for (_i = 0, weightsManifest_1 = weightsManifest; _i < weightsManifest_1.length; _i++) { + entry = weightsManifest_1[_i]; + weightSpecs.push.apply(weightSpecs, entry.weights); + } + fetchURLs = []; + weightsManifest.forEach(function (weightsGroup) { + weightsGroup.paths.forEach(function (path) { + fetchURLs.push(pathPrefix + path + suffix); + }); + }); + return [4 /*yield*/, loadWeightsAsArrayBuffer(fetchURLs, { + requestInit: this.requestInit, + fetchFunc: this.fetch, + onProgress: this.onProgress + })]; + case 1: + buffers = _b.sent(); + return [2 /*return*/, [weightSpecs, concatenateArrayBuffers(buffers)]]; + } + }); + }); + }; + HTTPRequest.URL_SCHEME_REGEX = /^https?:\/\//; + return HTTPRequest; + }()); + /** + * Extract the prefix and suffix of the url, where the prefix is the path before + * the last file, and suffix is the search params after the last file. + * ``` + * const url = 'http://tfhub.dev/model/1/tensorflowjs_model.pb?tfjs-format=file' + * [prefix, suffix] = parseUrl(url) + * // prefix = 'http://tfhub.dev/model/1/' + * // suffix = '?tfjs-format=file' + * ``` + * @param url the model url to be parsed. + */ + function parseUrl(url) { + var lastSlash = url.lastIndexOf('/'); + var lastSearchParam = url.lastIndexOf('?'); + var prefix = url.substring(0, lastSlash); + var suffix = lastSearchParam > lastSlash ? url.substring(lastSearchParam) : ''; + return [prefix + '/', suffix]; + } + function isHTTPScheme(url) { + return url.match(HTTPRequest.URL_SCHEME_REGEX) != null; + } + var httpRouter = function (url, onProgress) { + if (typeof fetch === 'undefined') { + // `http` uses `fetch` or `node-fetch`, if one wants to use it in + // an environment that is not the browser or node they have to setup a + // global fetch polyfill. + return null; + } + else { + var isHTTP = true; + if (Array.isArray(url)) { + isHTTP = url.every(function (urlItem) { return isHTTPScheme(urlItem); }); + } + else { + isHTTP = isHTTPScheme(url); + } + if (isHTTP) { + return http(url, { onProgress: onProgress }); + } + } + return null; + }; + IORouterRegistry.registerSaveRouter(httpRouter); + IORouterRegistry.registerLoadRouter(httpRouter); + /** + * Creates an IOHandler subtype that sends model artifacts to HTTP server. + * + * An HTTP request of the `multipart/form-data` mime type will be sent to the + * `path` URL. The form data includes artifacts that represent the topology + * and/or weights of the model. In the case of Keras-style `tf.Model`, two + * blobs (files) exist in form-data: + * - A JSON file consisting of `modelTopology` and `weightsManifest`. + * - A binary weights file consisting of the concatenated weight values. + * These files are in the same format as the one generated by + * [tfjs_converter](https://js.tensorflow.org/tutorials/import-keras.html). + * + * The following code snippet exemplifies the client-side code that uses this + * function: + * + * ```js + * const model = tf.sequential(); + * model.add( + * tf.layers.dense({units: 1, inputShape: [100], activation: 'sigmoid'})); + * + * const saveResult = await model.save(tf.io.http( + * 'http://model-server:5000/upload', {method: 'PUT'})); + * console.log(saveResult); + * ``` + * + * If the default `POST` method is to be used, without any custom parameters + * such as headers, you can simply pass an HTTP or HTTPS URL to `model.save`: + * + * ```js + * const saveResult = await model.save('http://model-server:5000/upload'); + * ``` + * + * The following GitHub Gist + * https://gist.github.com/dsmilkov/1b6046fd6132d7408d5257b0976f7864 + * implements a server based on [flask](https://github.com/pallets/flask) that + * can receive the request. Upon receiving the model artifacts via the requst, + * this particular server reconsistutes instances of [Keras + * Models](https://keras.io/models/model/) in memory. + * + * + * @param path A URL path to the model. + * Can be an absolute HTTP path (e.g., + * 'http://localhost:8000/model-upload)') or a relative path (e.g., + * './model-upload'). + * @param requestInit Request configurations to be used when sending + * HTTP request to server using `fetch`. It can contain fields such as + * `method`, `credentials`, `headers`, `mode`, etc. See + * https://developer.mozilla.org/en-US/docs/Web/API/Request/Request + * for more information. `requestInit` must not have a body, because the + * body will be set by TensorFlow.js. File blobs representing the model + * topology (filename: 'model.json') and the weights of the model (filename: + * 'model.weights.bin') will be appended to the body. If `requestInit` has a + * `body`, an Error will be thrown. + * @param loadOptions Optional configuration for the loading. It includes the + * following fields: + * - weightPathPrefix Optional, this specifies the path prefix for weight + * files, by default this is calculated from the path param. + * - fetchFunc Optional, custom `fetch` function. E.g., in Node.js, + * the `fetch` from node-fetch can be used here. + * - onProgress Optional, progress callback function, fired periodically + * before the load is completed. + * @returns An instance of `IOHandler`. + */ + /** + * @doc { + * heading: 'Models', + * subheading: 'Loading', + * namespace: 'io', + * ignoreCI: true + * } + */ + function http(path, loadOptions) { + return new HTTPRequest(path, loadOptions); + } + /** + * Deprecated. Use `tf.io.http`. + * @param path + * @param loadOptions + */ + function browserHTTPRequest(path, loadOptions) { + return http(path, loadOptions); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + var PassthroughLoader = /** @class */ (function () { + function PassthroughLoader(modelArtifacts) { + this.modelArtifacts = modelArtifacts; + } + PassthroughLoader.prototype.load = function () { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.modelArtifacts]; + }); + }); + }; + return PassthroughLoader; + }()); + var PassthroughSaver = /** @class */ (function () { + function PassthroughSaver(saveHandler) { + this.saveHandler = saveHandler; + } + PassthroughSaver.prototype.save = function (modelArtifacts) { + return __awaiter(this, void 0, void 0, function () { + return __generator(this, function (_a) { + return [2 /*return*/, this.saveHandler(modelArtifacts)]; + }); + }); + }; + return PassthroughSaver; + }()); + /** + * Creates an IOHandler that loads model artifacts from memory. + * + * When used in conjunction with `tf.loadLayersModel`, an instance of + * `tf.LayersModel` (Keras-style) can be constructed from the loaded artifacts. + * + * ```js + * const model = await tf.loadLayersModel(tf.io.fromMemory( + * modelTopology, weightSpecs, weightData)); + * ``` + * + * @param modelArtifacts a object containing model topology (i.e., parsed from + * the JSON format). + * @param weightSpecs An array of `WeightsManifestEntry` objects describing the + * names, shapes, types, and quantization of the weight data. + * @param weightData A single `ArrayBuffer` containing the weight data, + * concatenated in the order described by the weightSpecs. + * @param trainingConfig Model training configuration. Optional. + * + * @returns A passthrough `IOHandler` that simply loads the provided data. + */ + function fromMemory(modelArtifacts, weightSpecs, weightData, trainingConfig) { + if (arguments.length === 1) { + var isModelArtifacts = modelArtifacts.modelTopology != null || + modelArtifacts.weightSpecs != null; + if (isModelArtifacts) { + return new PassthroughLoader(modelArtifacts); + } + else { + // Legacy support: with only modelTopology. + // TODO(cais): Remove this deprecated API. + console.warn('Please call tf.io.fromMemory() with only one argument. ' + + 'The argument should be of type ModelArtifacts. ' + + 'The multi-argument signature of tf.io.fromMemory() has been ' + + 'deprecated and will be removed in a future release.'); + return new PassthroughLoader({ modelTopology: modelArtifacts }); + } + } + else { + // Legacy support. + // TODO(cais): Remove this deprecated API. + console.warn('Please call tf.io.fromMemory() with only one argument. ' + + 'The argument should be of type ModelArtifacts. ' + + 'The multi-argument signature of tf.io.fromMemory() has been ' + + 'deprecated and will be removed in a future release.'); + return new PassthroughLoader({ + modelTopology: modelArtifacts, + weightSpecs: weightSpecs, + weightData: weightData, + trainingConfig: trainingConfig + }); + } + } + /** + * Creates an IOHandler that passes saved model artifacts to a callback. + * + * ```js + * function handleSave(artifacts) { + * // ... do something with the artifacts ... + * return {modelArtifactsInfo: {...}, ...}; + * } + * + * const saveResult = model.save(tf.io.withSaveHandler(handleSave)); + * ``` + * + * @param saveHandler A function that accepts a `ModelArtifacts` and returns a + * `SaveResult`. + */ + function withSaveHandler(saveHandler) { + return new PassthroughSaver(saveHandler); + } + + /** + * @license + * Copyright 2018 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + + var io = /*#__PURE__*/Object.freeze({ + browserFiles: browserFiles, + browserHTTPRequest: browserHTTPRequest, + concatenateArrayBuffers: concatenateArrayBuffers, + decodeWeights: decodeWeights, + encodeWeights: encodeWeights, + fromMemory: fromMemory, + getLoadHandlers: getLoadHandlers, + getModelArtifactsInfoForJSON: getModelArtifactsInfoForJSON, + getSaveHandlers: getSaveHandlers, + http: http, + isHTTPScheme: isHTTPScheme, + loadWeights: loadWeights, + registerLoadRouter: registerLoadRouter, + registerSaveRouter: registerSaveRouter, + weightsLoaderFactory: weightsLoaderFactory, + withSaveHandler: withSaveHandler, + copyModel: copyModel, + listModels: listModels, + moveModel: moveModel, + removeModel: removeModel + }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Computes the confusion matrix from true labels and predicted labels. + * + * ```js + * const labels = tf.tensor1d([0, 1, 2, 1, 0], 'int32'); + * const predictions = tf.tensor1d([0, 2, 2, 1, 0], 'int32'); + * const numClasses = 3; + * const out = tf.math.confusionMatrix(labels, predictions, numClasses); + * out.print(); + * // Expected output matrix: + * // [[2, 0, 0], + * // [0, 1, 1], + * // [0, 0, 1]] + * ``` + * + * @param labels The target labels, assumed to be 0-based integers + * for the classes. The shape is `[numExamples]`, where + * `numExamples` is the number of examples included. + * @param predictions The predicted classes, assumed to be + * 0-based integers for the classes. Must have the same shape as `labels`. + * @param numClasses Number of all classes, as an integer. + * Its value must be larger than the largest element in `labels` and + * `predictions`. + * @returns The confusion matrix as a int32-type 2D tensor. The value at + * row `r` and column `c` is the number of times examples of actual class + * `r` were predicted as class `c`. + */ + /** @doc {heading: 'Operations', subheading: 'Evaluation'} */ + function confusionMatrix_(labels, predictions, numClasses) { + var $labels = convertToTensor(labels, 'labels', 'confusionMatrix'); + var $predictions = convertToTensor(predictions, 'predictions', 'confusionMatrix'); + assert(numClasses == null || numClasses > 0 && Number.isInteger(numClasses), function () { return "If provided, numClasses must be a positive integer, " + + ("but got " + numClasses); }); + assert($labels.rank === 1, function () { return "Expected the rank of labels to be 1, but got " + $labels.rank; }); + assert($predictions.rank === 1, function () { return "Expected the rank of predictions to be 1, " + + ("but got " + $predictions.rank); }); + assert($labels.shape[0] === $predictions.shape[0], function () { return "Mismatch in the number of examples: " + + ($labels.shape[0] + " vs. " + $predictions.shape[0] + ". ") + + "Labels and predictions should have the same number of elements."; }); + assert(numClasses > 0 && Number.isInteger(numClasses), function () { return "numClasses is required to be a positive integer, but got " + + ("" + numClasses); }); + // TODO(cais): In the future, if oneHot supports tensors inputs for + // `numClasses`, `confusionMatrix` can make `numClasses` optional. + var oneHotLabels = oneHot($labels.asType('int32'), numClasses); + var oneHotPredictions = oneHot($predictions.asType('int32'), numClasses); + return oneHotLabels.transpose().matMul(oneHotPredictions).asType('int32'); + } + var confusionMatrix = op({ confusionMatrix_: confusionMatrix_ }); + + /** + * @license + * Copyright 2018 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + + var math = /*#__PURE__*/Object.freeze({ + confusionMatrix: confusionMatrix + }); + + /** + * @license + * Copyright 2019 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + /** + * Creates a `tf.Tensor` from an image. + * + * ```js + * const image = new ImageData(1, 1); + * image.data[0] = 100; + * image.data[1] = 150; + * image.data[2] = 200; + * image.data[3] = 255; + * + * tf.browser.fromPixels(image).print(); + * ``` + * + * @param pixels The input image to construct the tensor from. The + * supported image types are all 4-channel. You can also pass in an image + * object with following attributes: + * `{data: Uint8Array; width: number; height: number}` + * @param numChannels The number of channels of the output tensor. A + * numChannels value less than 4 allows you to ignore channels. Defaults to + * 3 (ignores alpha channel of input image). + */ + /** @doc {heading: 'Browser', namespace: 'browser', ignoreCI: true} */ + function fromPixels_(pixels, numChannels) { + if (numChannels === void 0) { numChannels = 3; } + if (numChannels > 4) { + throw new Error('Cannot construct Tensor with more than 4 channels from pixels.'); + } + var isVideo = typeof (HTMLVideoElement) !== 'undefined' && + pixels instanceof HTMLVideoElement; + if (isVideo) { + var HAVE_CURRENT_DATA_READY_STATE = 2; + if (isVideo && + pixels.readyState < + HAVE_CURRENT_DATA_READY_STATE) { + throw new Error('The video element has not loaded data yet. Please wait for ' + + '`loadeddata` event on the