From 872ba97d54e151d8dd349ceeb02d12198ed3bbaa Mon Sep 17 00:00:00 2001 From: Na Li Date: Thu, 28 May 2020 14:16:18 -0700 Subject: [PATCH 01/14] Resolve merge conflict. --- .gitignore | 3 +++ e2e/scripts/test-ci.sh | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/.gitignore b/.gitignore index e5ac83d4aec..876d2da5b8a 100644 --- a/.gitignore +++ b/.gitignore @@ -28,8 +28,11 @@ tfjs-layers/integration_tests/tfjs2keras/test-data/ tfjs-layers/integration/typescript/yarn.lock e2e/integration_tests/create_save_predict_data e2e/integration_tests/convert_predict_data +<<<<<<< HEAD tfjs-converter/python/tensorflowjs.egg-info tfjs-inference/binaries +======= +>>>>>>> 76944b87... Move converter validation to e2e. # Ignore the src, binding, scripts of tfjs-node-gpu since it is copied over when # building. diff --git a/e2e/scripts/test-ci.sh b/e2e/scripts/test-ci.sh index 897b61fee2e..0b4a89a1b32 100755 --- a/e2e/scripts/test-ci.sh +++ b/e2e/scripts/test-ci.sh @@ -37,6 +37,11 @@ if [[ "$TAGS" == *"#REGRESSION"* ]]; then python create_save_predict.py echo "Create saved models and convert." +<<<<<<< HEAD +======= + rm -rf "convert_predict_data/" + mkdir "convert_predict_data/" +>>>>>>> 76944b87... Move converter validation to e2e. python convert_predict.py # Cleanup python env. From afdb3650dbcd673f405b10a0b295026c7fac64a0 Mon Sep 17 00:00:00 2001 From: Na Li Date: Sat, 13 Jun 2020 14:49:23 -0700 Subject: [PATCH 02/14] Resolve merge conflict. --- e2e/scripts/test-ci.sh | 5 ----- 1 file changed, 5 deletions(-) diff --git a/e2e/scripts/test-ci.sh b/e2e/scripts/test-ci.sh index 0b4a89a1b32..897b61fee2e 100755 --- a/e2e/scripts/test-ci.sh +++ b/e2e/scripts/test-ci.sh @@ -37,11 +37,6 @@ if [[ "$TAGS" == *"#REGRESSION"* ]]; then python create_save_predict.py echo "Create saved models and convert." -<<<<<<< HEAD -======= - rm -rf "convert_predict_data/" - mkdir "convert_predict_data/" ->>>>>>> 76944b87... Move converter validation to e2e. python convert_predict.py # Cleanup python env. From 13e827f77a4e0bfe8baa9d45b58835c5945ce025 Mon Sep 17 00:00:00 2001 From: Na Li Date: Sat, 13 Jun 2020 14:31:18 -0700 Subject: [PATCH 03/14] Implement dilation2d op. --- tfjs-core/src/engine.ts | 1 + tfjs-core/src/kernel_names.ts | 9 +++ tfjs-core/src/ops/dilation2d.ts | 105 +++++++++++++++++++++++++++ tfjs-core/src/ops/dilation2d_test.ts | 38 ++++++++++ tfjs-core/src/ops/ops.ts | 1 + tfjs-core/src/tests.ts | 1 + 6 files changed, 155 insertions(+) create mode 100644 tfjs-core/src/ops/dilation2d.ts create mode 100644 tfjs-core/src/ops/dilation2d_test.ts diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index 4ee48fa5a10..a92e5ff8b29 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -555,6 +555,7 @@ export class Engine implements TensorTracker, DataMover { let kernelFunc: () => Tensor[]; const kernel = getKernel(kernelName, this.backendName); + let out: TensorInfo|TensorInfo[]; if (kernel != null) { kernelFunc = () => { diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index b8c5f3bd845..c00a46a136c 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -192,6 +192,15 @@ export type DepthwiseConv2dNativeBackpropInputInputs = export const Diag = 'Diag'; export type DiagInputs = Pick; +export const Dilation2D = 'Dilation2D'; +export type Dilation2DInputs = Pick; +export interface Dilation2DAttrs { + strides: [number, number]|number; + pad: 'valid'|'same'|number; + dataFormat: 'NHWC'; + dilations: [number, number]|number; +} + export const Div = 'Div'; export type DivInputs = BinaryInputs; diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts new file mode 100644 index 00000000000..eb24eb36c21 --- /dev/null +++ b/tfjs-core/src/ops/dilation2d.ts @@ -0,0 +1,105 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Dilation2D, Dilation2DAttrs, Dilation2DInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {reshape} from './array_ops'; +import {op} from './operation'; + + + +/** + * Computes the grayscale dilation over the input `x`. + * + * @param x The input tensor, rank 3 or rank 4 of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param filter The filter tensor, rank 3, of shape + * `[filterHeight, filterWidth, depth]`. + * @param strides The strides of the sliding window for each dimension of the + * input tensor: `[strideHeight, strideWidth]`. + * If `strides` is a single number, + * then `strideHeight == strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1*1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat Specify the data format of the input and output data. + * Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations` + * is a single number, then `dilationHeight == dilationWidth`. If it is + * greater than 1, then all values of `strides` must be 1. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function dilation2d_( + x: T|TensorLike, filter: Tensor3D|TensorLike, + strides: [number, number]|number, pad: 'valid'|'same'|number, + dataFormat: 'NHWC' = 'NHWC', + dilations: [number, number]|number = [1, 1]): T { + const $x = convertToTensor(x, 'x', 'dilation2d'); + const $filter = convertToTensor(filter, 'filter', 'dilation2d'); + + let x4D = $x as Tensor4D; + let reshapedTo4D = false; + + if ($x.rank === 3) { + x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]); + reshapedTo4D = true; + } + + util.assert( + x4D.rank === 4, + () => `Error in dilation2d: input must be rank 4, but got rank ` + + `${x4D.rank}.`); + util.assert( + $filter.rank === 3, + () => `Error in dilation2d: filter must be rank 3, but got rank ` + + `${$filter.rank}.`); + util.assert( + dataFormat === 'NHWC', + () => `Error in dilation2d: Only NHWC is currently supported, ` + + `but got dataFormat of ${dataFormat}`); + + const inputs: Dilation2DInputs = {x: $x, filter: $filter}; + const attrs: Dilation2DAttrs = {strides, pad, dataFormat, dilations}; + + const res = ENGINE.runKernel( + Dilation2D, inputs as {} as NamedTensorMap, + attrs as {} as NamedAttrMap) as T; + + if (reshapedTo4D) { + return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) as T; + } + + return res; +} + +export const dilation2d = op({dilation2d_}); diff --git a/tfjs-core/src/ops/dilation2d_test.ts b/tfjs-core/src/ops/dilation2d_test.ts new file mode 100644 index 00000000000..ddad7c1eb37 --- /dev/null +++ b/tfjs-core/src/ops/dilation2d_test.ts @@ -0,0 +1,38 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('dilation2d', ALL_ENVS, () => { + it('Should throw error.', async () => { + const inputDepth = 1; + const inputShape: [number, number, number] = [2, 2, inputDepth]; + const filterShape: [number, number, number] = [1, 1, inputDepth]; + + const pad = 'same'; + const stride: [number, number] = [1, 1]; + + const x = tf.tensor3d([1, 1, 1, 1], inputShape); + const filter = tf.tensor3d([1], filterShape); + + const result = tf.dilation2d(x, filter, stride, pad); + + expectArraysClose(await result.data(), [1, 1, 1, 1]); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index ab66c92d783..490d1d9ae94 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -43,6 +43,7 @@ export {cumsum} from './cumsum'; export {depthToSpace} from './depth_to_space'; export {depthwiseConv2d} from './depthwise_conv2d'; export {diag} from './diag'; +export {dilation2d} from './dilation2d'; export {div} from './div'; export {divNoNan} from './div_no_nan'; export {dot} from './dot'; diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index f144ae92357..6a6492563d7 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -69,6 +69,7 @@ import './ops/conv_util_test'; import './ops/cumsum_test'; import './ops/depth_to_space_test'; import './ops/diag_test'; +import './ops/dilation2d_test'; import './ops/dropout_test'; import './ops/dropout_util_test'; import './ops/elu_test'; From a74562d2644c1e4ddac47b029de7ea5c4bd1698d Mon Sep 17 00:00:00 2001 From: Na Li Date: Sat, 13 Jun 2020 14:53:02 -0700 Subject: [PATCH 04/14] . --- .gitignore | 3 --- tfjs-core/src/engine.ts | 1 - 2 files changed, 4 deletions(-) diff --git a/.gitignore b/.gitignore index 876d2da5b8a..e5ac83d4aec 100644 --- a/.gitignore +++ b/.gitignore @@ -28,11 +28,8 @@ tfjs-layers/integration_tests/tfjs2keras/test-data/ tfjs-layers/integration/typescript/yarn.lock e2e/integration_tests/create_save_predict_data e2e/integration_tests/convert_predict_data -<<<<<<< HEAD tfjs-converter/python/tensorflowjs.egg-info tfjs-inference/binaries -======= ->>>>>>> 76944b87... Move converter validation to e2e. # Ignore the src, binding, scripts of tfjs-node-gpu since it is copied over when # building. diff --git a/tfjs-core/src/engine.ts b/tfjs-core/src/engine.ts index a92e5ff8b29..4ee48fa5a10 100644 --- a/tfjs-core/src/engine.ts +++ b/tfjs-core/src/engine.ts @@ -555,7 +555,6 @@ export class Engine implements TensorTracker, DataMover { let kernelFunc: () => Tensor[]; const kernel = getKernel(kernelName, this.backendName); - let out: TensorInfo|TensorInfo[]; if (kernel != null) { kernelFunc = () => { From 41a46e154e542af8c815a667dd840cb3dc910657 Mon Sep 17 00:00:00 2001 From: Na Li Date: Fri, 19 Jun 2020 17:46:20 -0700 Subject: [PATCH 05/14] Resolve merge conflict. --- tfjs-backend-cpu/run_tests.ts | 3 +- tfjs-backend-cpu/src/kernels/Dilation2D.ts | 99 ++++++ .../src/kernels/Dilation2DBackpropFilter.ts | 117 ++++++++ .../src/kernels/Dilation2DBackpropInput.ts | 117 ++++++++ tfjs-backend-cpu/src/register_all_kernels.ts | 9 +- tfjs-backend-webgl/src/setup_test.ts | 5 + tfjs-converter/docs/supported_ops.md | 1 + .../tensorflowjs/op_list/convolution.json | 33 ++ .../executors/convolution_executor.ts | 23 ++ .../executors/convolution_executor_test.ts | 22 ++ .../src/operations/op_list/convolution.ts | 19 +- tfjs-core/src/gradients/Dilation2D_grad.ts | 42 +++ tfjs-core/src/jasmine_util.ts | 21 +- tfjs-core/src/kernel_names.ts | 8 + tfjs-core/src/ops/conv_util.ts | 45 ++- tfjs-core/src/ops/dilation2d.ts | 10 +- tfjs-core/src/ops/dilation2d_test.ts | 284 +++++++++++++++++- .../src/public/chained_ops/dilation2d.ts | 36 +++ .../chained_ops/register_all_chained_ops.ts | 1 + .../register_all_chained_ops_test.ts | 1 + tfjs-core/src/register_all_gradients.ts | 2 + tfjs-core/src/util.ts | 21 +- tfjs-node/src/kernels/Dilation2D.ts | 54 ++++ .../src/kernels/Dilation2DBackpropFilter.ts | 55 ++++ .../src/kernels/Dilation2DBackpropInput.ts | 55 ++++ 25 files changed, 1056 insertions(+), 27 deletions(-) create mode 100644 tfjs-backend-cpu/src/kernels/Dilation2D.ts create mode 100644 tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts create mode 100644 tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts create mode 100644 tfjs-core/src/gradients/Dilation2D_grad.ts create mode 100644 tfjs-core/src/public/chained_ops/dilation2d.ts create mode 100644 tfjs-node/src/kernels/Dilation2D.ts create mode 100644 tfjs-node/src/kernels/Dilation2DBackpropFilter.ts create mode 100644 tfjs-node/src/kernels/Dilation2DBackpropInput.ts diff --git a/tfjs-backend-cpu/run_tests.ts b/tfjs-backend-cpu/run_tests.ts index 822f8f514b1..7152c2e5874 100644 --- a/tfjs-backend-cpu/run_tests.ts +++ b/tfjs-backend-cpu/run_tests.ts @@ -38,14 +38,13 @@ const cpuTests = 'src/**/*_test.ts'; const runner = new jasmineCtor(); runner.loadConfig({spec_files: [cpuTests, coreTests], random: false}); -// customInclude takes higher priority then TEST_FILTERS, only when -// customInclude return false will TEST_FILTERS be considered. const TEST_FILTERS: TestFilter[] = []; const customInclude = (testName: string) => { // Exclude webworker test if (testName.includes('computation in worker')) { return false; } + // Include all other tests. return true; }; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2D.ts b/tfjs-backend-cpu/src/kernels/Dilation2D.ts new file mode 100644 index 00000000000..747e03edab7 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Dilation2D.ts @@ -0,0 +1,99 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, TypedArray, util} from '@tensorflow/tfjs-core'; +import {KernelConfig} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export const dilation2dConfig: KernelConfig = { + kernelName: Dilation2D, + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter} = inputs as Dilation2DInputs; + const {strides, pad, dataFormat, dilations} = + attrs as {} as Dilation2DAttrs; + const cpuBackend = backend as MathBackendCPU; + + const $x = + util.toNestedArray( + x.shape, cpuBackend.data.get(x.dataId).values as TypedArray) as + number[][][][]; + + const $filter = util.toNestedArray( + filter.shape, + cpuBackend.data.get(filter.dataId).values as + TypedArray) as number[][][]; + + const { + batchSize, + inHeight, + inWidth, + inChannels, + outHeight, + outWidth, + padInfo, + strideHeight, + strideWidth, + filterHeight, + filterWidth, + dilationHeight, + dilationWidth, + outShape + } = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, dataFormat, + dilations); + + const output = + util.makeZerosNestedTypedArray(outShape, x.dtype) as number[][][][]; + + for (let b = 0; b < batchSize; ++b) { + for (let hOut = 0; hOut < outHeight; ++hOut) { + const hBeg = hOut * strideHeight - padInfo.top; + for (let wOut = 0; wOut < outWidth; ++wOut) { + const wBeg = wOut * strideWidth - padInfo.left; + for (let d = 0; d < inChannels; ++d) { + let curVal = Number.MIN_SAFE_INTEGER; + for (let h = 0; h < filterHeight; ++h) { + const hIn = hBeg + h * dilationHeight; + if (hIn >= 0 && hIn < inHeight) { + for (let w = 0; w < filterWidth; ++w) { + const wIn = wBeg + w * dilationWidth; + if (wIn >= 0 && wIn < inWidth) { + const val = $x[b][hIn][wIn][d] + $filter[h][w][d]; + if (val > curVal) { + curVal = val; + } + } + } + } + } + output[b][hOut][wOut][d] = curVal; + } + } + } + } + + const dataId = cpuBackend.write( + util.toTypedArray(output, x.dtype, false /* debug mode */), outShape, + x.dtype); + + return {dataId, shape: outShape, dtype: x.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts new file mode 100644 index 00000000000..d08d3555b96 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts @@ -0,0 +1,117 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropFilter, Tensor3D, Tensor4D, TypedArray, util} from '@tensorflow/tfjs-core'; +import {KernelConfig} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export const dilation2dBackpropFilterConfig: KernelConfig = { + kernelName: Dilation2DBackpropFilter, + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = + inputs as {x: Tensor4D, filter: Tensor3D, dy: Tensor4D}; + const {strides, pad, dataFormat, dilations} = + attrs as {} as Dilation2DAttrs; + const cpuBackend = backend as MathBackendCPU; + + const $x = + util.toNestedArray( + x.shape, cpuBackend.data.get(x.dataId).values as TypedArray) as + number[][][][]; + + const $filter = util.toNestedArray( + filter.shape, + cpuBackend.data.get(filter.dataId).values as + TypedArray) as number[][][]; + + const { + batchSize, + inHeight, + inWidth, + inChannels, + outHeight, + outWidth, + padInfo, + strideHeight, + strideWidth, + filterHeight, + filterWidth, + dilationHeight, + dilationWidth, + outShape + } = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, dataFormat, + dilations); + + util.assert( + dy.rank === outShape.length, + () => `Error in ${Dilation2DBackpropFilter}, dy ` + + `must have the same rank as output ${outShape.length}, but got ` + + `${dy.rank}`); + + const $dy = + util.toNestedArray( + outShape, cpuBackend.data.get(dy.dataId).values as TypedArray) as + number[][][][]; + + // The computed filter gradients has the same dimensions as the filter: + // [filterHeight, filterWidth, depth] + const gradients = util.makeZerosNestedTypedArray( + filter.shape, filter.dtype) as number[][][]; + + for (let b = 0; b < batchSize; ++b) { + for (let hOut = 0; hOut < outHeight; ++hOut) { + const hBeg = hOut * strideHeight - padInfo.top; + for (let wOut = 0; wOut < outWidth; ++wOut) { + const wBeg = wOut * strideWidth - padInfo.left; + for (let d = 0; d < inChannels; ++d) { + let curVal = Number.MIN_SAFE_INTEGER; + let hMax = 0; + let wMax = 0; + for (let h = 0; h < filterHeight; ++h) { + const hIn = hBeg + h * dilationHeight; + if (hIn >= 0 && hIn < inHeight) { + for (let w = 0; w < filterWidth; ++w) { + const wIn = wBeg + w * dilationWidth; + if (wIn >= 0 && wIn < inWidth) { + const val = $x[b][hIn][wIn][d] + $filter[h][w][d]; + if (val > curVal) { + curVal = val; + hMax = h; + wMax = w; + } + } + } + } + } + gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d]; + } + } + } + } + + const dataId = cpuBackend.write( + util.toTypedArray(gradients, x.dtype, false /* debug mode */), + filter.shape, filter.dtype); + + return {dataId, shape: filter.shape, dtype: filter.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts new file mode 100644 index 00000000000..f4c31a199e7 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts @@ -0,0 +1,117 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropInput, Tensor3D, Tensor4D, TypedArray, util} from '@tensorflow/tfjs-core'; +import {KernelConfig} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export const dilation2dBackpropInputConfig: KernelConfig = { + kernelName: Dilation2DBackpropInput, + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = + inputs as {x: Tensor4D, filter: Tensor3D, dy: Tensor4D}; + const {strides, pad, dataFormat, dilations} = + attrs as {} as Dilation2DAttrs; + const cpuBackend = backend as MathBackendCPU; + + const $x = + util.toNestedArray( + x.shape, cpuBackend.data.get(x.dataId).values as TypedArray) as + number[][][][]; + + const $filter = util.toNestedArray( + filter.shape, + cpuBackend.data.get(filter.dataId).values as + TypedArray) as number[][][]; + + const { + batchSize, + inHeight, + inWidth, + inChannels, + outHeight, + outWidth, + padInfo, + strideHeight, + strideWidth, + filterHeight, + filterWidth, + dilationHeight, + dilationWidth, + outShape + } = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, dataFormat, + dilations); + + util.assert( + dy.rank === outShape.length, + () => `Error in ${Dilation2DBackpropInput}, dy ` + + `must have the same rank as output ${outShape.length}, but got ` + + `${dy.rank}`); + + const $dy = + util.toNestedArray( + outShape, cpuBackend.data.get(dy.dataId).values as TypedArray) as + number[][][][]; + + // The computed gradients has the same dimensions as the input: + // [batch, inputHeight, inputCols, inChannel] + const gradients = + util.makeZerosNestedTypedArray(x.shape, x.dtype) as number[][][][]; + + for (let b = 0; b < batchSize; ++b) { + for (let hOut = 0; hOut < outHeight; ++hOut) { + const hBeg = hOut * strideHeight - padInfo.top; + for (let wOut = 0; wOut < outWidth; ++wOut) { + const wBeg = wOut * strideWidth - padInfo.left; + for (let d = 0; d < inChannels; ++d) { + let curVal = Number.MIN_SAFE_INTEGER; + let hInMax = (hBeg < 0) ? 0 : hBeg; + let wInMax = (wBeg < 0) ? 0 : wBeg; + for (let h = 0; h < filterHeight; ++h) { + const hIn = hBeg + h * dilationHeight; + if (hIn >= 0 && hIn < inHeight) { + for (let w = 0; w < filterWidth; ++w) { + const wIn = wBeg + w * dilationWidth; + if (wIn >= 0 && wIn < inWidth) { + const val = $x[b][hIn][wIn][d] + $filter[h][w][d]; + if (val > curVal) { + curVal = val; + hInMax = hIn; + wInMax = wIn; + } + } + } + } + } + gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d]; + } + } + } + } + + const dataId = cpuBackend.write( + util.toTypedArray(gradients, x.dtype, false /* debug mode */), x.shape, + x.dtype); + + return {dataId, shape: x.shape, dtype: x.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index a1a6a66af82..bb0c38ea531 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -19,6 +19,9 @@ // the contents of this file and import only the kernels that are needed. import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {dilation2dConfig} from './kernels/Dilation2D'; +import {dilation2dBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter'; +import {dilation2dBackpropInputConfig} from './kernels/Dilation2DBackpropInput'; import {divConfig} from './kernels/Div'; import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; @@ -29,8 +32,10 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig, - transposeConfig, maxPoolWithArgmaxConfig, maxConfig + dilation2dConfig, dilation2dBackpropInputConfig, + dilation2dBackpropFilterConfig, nonMaxSuppressionV5Config, squareConfig, + squaredDifferenceConfig, divConfig, transposeConfig, maxPoolWithArgmaxConfig, + maxConfig ]; for (const kernelConfig of kernelConfigs) { diff --git a/tfjs-backend-webgl/src/setup_test.ts b/tfjs-backend-webgl/src/setup_test.ts index b0bd8043453..ff7f0531290 100644 --- a/tfjs-backend-webgl/src/setup_test.ts +++ b/tfjs-backend-webgl/src/setup_test.ts @@ -26,6 +26,11 @@ const customInclude = (testName: string) => { if (testName.includes(subStr)) { return false; } + + // Not implemented yet. + if (testName.includes('dilation2d')) { + return false; + } } return true; }; diff --git a/tfjs-converter/docs/supported_ops.md b/tfjs-converter/docs/supported_ops.md index cb744bcec0b..b7692994afe 100644 --- a/tfjs-converter/docs/supported_ops.md +++ b/tfjs-converter/docs/supported_ops.md @@ -112,6 +112,7 @@ |Conv2DBackpropInput|conv2dTranspose| |DepthwiseConv2d|depthwiseConv2d| |DepthwiseConv2dNative|depthwiseConv2d| +|Dilation2D|dilation2d| |MaxPool|maxPool| |MaxPool3D|maxPool3d| |Not mapped|pool| diff --git a/tfjs-converter/python/tensorflowjs/op_list/convolution.json b/tfjs-converter/python/tensorflowjs/op_list/convolution.json index e6bfe28c238..31240342989 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/convolution.json +++ b/tfjs-converter/python/tensorflowjs/op_list/convolution.json @@ -627,5 +627,38 @@ "type": "number[]" } ] + }, + { + "tfOpName": "Dilation2D", + "category": "convolution", + "inputs": [ + { + "start": 0, + "name": "x", + "type": "tensor" + }, + { + "start": 1, + "name": "filter", + "type": "tensor" + } + ], + "attrs": [ + { + "tfName": "strides", + "name": "strides", + "type": "number[]" + }, + { + "tfName": "rates", + "name": "dilations", + "type": "number[]" + }, + { + "tfName": "padding", + "name": "pad", + "type": "string" + } + ] } ] \ No newline at end of file diff --git a/tfjs-converter/src/operations/executors/convolution_executor.ts b/tfjs-converter/src/operations/executors/convolution_executor.ts index 0856d63d628..c59620a7c5c 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor.ts @@ -231,6 +231,29 @@ export const executeOp: InternalOpExecutor = (node: Node, [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; } + case 'Dilation2D': { + const strides = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const dilations = + getParamValue('dilations', node, tensorMap, context) as number[]; + + // strides: [1, stride_height, stride_width, 1]. + const strideHeight = strides[1]; + const strideWidth = strides[2]; + + // dilations: [1, dilation_height, dilation_width, 1]. + const dilationHeight = dilations[1]; + const dilationWidth = dilations[2]; + + return [tfc.dilation2d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + getParamValue('filter', node, tensorMap, context) as tfc.Tensor3D, + [strideHeight, strideWidth], pad as 'valid' | 'same', + [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)]; + } + default: throw TypeError(`Node type ${node.op} is not implemented`); } diff --git a/tfjs-converter/src/operations/executors/convolution_executor_test.ts b/tfjs-converter/src/operations/executors/convolution_executor_test.ts index 1ecec05769e..123436dc797 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor_test.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor_test.ts @@ -566,4 +566,26 @@ describe('convolution', () => { }); }); }); + + describe('dilation2d', () => { + it('should call tfc.dilation2d', () => { + spyOn(tfc, 'dilation2d'); + node.op = 'Dilation2D'; + node.inputParams['filter'] = createTensorAttr(1); + node.attrParams['strides'] = createNumericArrayAttr([1, 1, 1, 1]); + node.attrParams['pad'] = createStrAttr('same'); + node.attrParams['dataFormat'] = createStrAttr('NHWC'); + node.attrParams['dilations'] = createNumericArrayAttr([1, 2, 2, 1]); + + const input1 = [tfc.scalar(1.0)]; + const input2 = [tfc.scalar(1.0)]; + node.inputNames = ['input1', 'input2']; + + executeOp(node, {input1, input2}, context); + + expect(tfc.dilation2d) + .toHaveBeenCalledWith( + input1[0], input2[0], [1, 1], 'same', [2, 2], 'NHWC'); + }); + }); }); diff --git a/tfjs-converter/src/operations/op_list/convolution.ts b/tfjs-converter/src/operations/op_list/convolution.ts index 9665afa7a30..d0665b3455a 100644 --- a/tfjs-converter/src/operations/op_list/convolution.ts +++ b/tfjs-converter/src/operations/op_list/convolution.ts @@ -1,8 +1,6 @@ -import {OpMapper} from '../types'; - /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -17,6 +15,8 @@ import {OpMapper} from '../types'; * ============================================================================= */ +import {OpMapper} from '../types'; + export const json: OpMapper[] = [ { 'tfOpName': 'AvgPool', @@ -329,5 +329,18 @@ export const json: OpMapper[] = [ }, {'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]'} ], + }, + { + 'tfOpName': 'Dilation2D', + 'category': 'convolution', + 'inputs': [ + {'start': 0, 'name': 'x', 'type': 'tensor'}, + {'start': 1, 'name': 'filter', 'type': 'tensor'}, + ], + 'attrs': [ + {'tfName': 'strides', 'name': 'strides', 'type': 'number[]'}, + {'tfName': 'rates', 'name': 'dilations', 'type': 'number[]'}, + {'tfName': 'padding', 'name': 'pad', 'type': 'string'} + ] } ]; diff --git a/tfjs-core/src/gradients/Dilation2D_grad.ts b/tfjs-core/src/gradients/Dilation2D_grad.ts new file mode 100644 index 00000000000..9d1c3fa40fc --- /dev/null +++ b/tfjs-core/src/gradients/Dilation2D_grad.ts @@ -0,0 +1,42 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {ENGINE} from '../engine'; +import {Dilation2D, Dilation2DBackpropFilter, Dilation2DBackpropFilterInputs, Dilation2DBackpropInput, Dilation2DBackpropInputInputs} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; + +export const dilation2dGradConfig: GradConfig = { + kernelName: Dilation2D, + inputsToSave: ['x', 'filter'], + gradFunc: (dy: Tensor4D, saved: Tensor[], attrs: NamedAttrMap) => { + const [x, filter] = saved as [Tensor4D, Tensor3D]; + + const inputInputs: Dilation2DBackpropInputInputs = {x, filter, dy}; + const filterInputs: Dilation2DBackpropFilterInputs = {x, filter, dy}; + + return { + x: () => ENGINE.runKernel( + Dilation2DBackpropInput, inputInputs as {} as NamedTensorMap, + attrs) as Tensor, + filter: () => ENGINE.runKernel( + Dilation2DBackpropFilter, + filterInputs as {} as NamedTensorMap, attrs) as Tensor + }; + } +}; diff --git a/tfjs-core/src/jasmine_util.ts b/tfjs-core/src/jasmine_util.ts index 2ed8a0f1c02..a8b8cadd8e9 100644 --- a/tfjs-core/src/jasmine_util.ts +++ b/tfjs-core/src/jasmine_util.ts @@ -94,9 +94,26 @@ export interface TestFilter { excludes?: string[]; } +/** + * Add test filtering logic to Jasmine's specFilter hook. + * + * @param testFilters Used for include a test suite, with the ability + * to selectively exclude some of the tests. + * Either `include` or `startsWith` must exist for a `TestFilter`. + * Tests that have the substrings specified by the include or startsWith + * will be included in the test run, unless one of the substrings specified + * by `excludes` appears in the name. + * @param customInclude Function to programatically include a test. + * If this function returns true, a test will immediately run. Otherwise, + * `testFilters` is used for fine-grained filtering. + * + * If a test is not handled by `testFilters` or `customInclude`, the test will + * be excluded in the test run. + */ export function setupTestFilters( testFilters: TestFilter[], customInclude: (name: string) => boolean) { const env = jasmine.getEnv(); + // Account for --grep flag passed to karma by saving the existing specFilter. const grepFilter = env.specFilter; @@ -119,8 +136,7 @@ export function setupTestFilters( return true; } - // Include a describeWithFlags() test from tfjs-core only if the test is - // in the include list. + // Include tests of a test suite unless tests are in excludes list. for (let i = 0; i < testFilters.length; ++i) { const testFilter = testFilters[i]; if ((testFilter.include != null && @@ -137,6 +153,7 @@ export function setupTestFilters( return true; } } + // Otherwise ignore the test. return false; }; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index c00a46a136c..b1a6c5c6d66 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -201,6 +201,14 @@ export interface Dilation2DAttrs { dilations: [number, number]|number; } +export const Dilation2DBackpropInput = 'Dilation2DBackpropInput'; +export type Dilation2DBackpropInputInputs = + Pick; + +export const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter'; +export type Dilation2DBackpropFilterInputs = + Pick; + export const Div = 'Div'; export type DivInputs = BinaryInputs; diff --git a/tfjs-core/src/ops/conv_util.ts b/tfjs-core/src/ops/conv_util.ts index cffae54dfc4..845b246bff2 100644 --- a/tfjs-core/src/ops/conv_util.ts +++ b/tfjs-core/src/ops/conv_util.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2017 Google Inc. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -73,6 +73,49 @@ export type Conv2DInfo = { filterShape: [number, number, number, number] }; +/** + * + * @param inputShape Input tensor shape is of the following dimensions: + * `[batch, height, width, inChannels]`. + * @param filterShape The filter shape is of the following dimensions: + * `[filterHeight, filterWidth, depth]`. + * @param strides The strides of the sliding window for each dimension of the + * input tensor: `[strideHeight, strideWidth]`. + * If `strides` is a single number, + * then `strideHeight == strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1*1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat The data format of the input and output data. + * Defaults to 'NHWC'. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`. + * Defaults to `[1, 1]`. If `dilations` is a single number, then + * `dilationHeight == dilationWidth`. + */ +export function computeDilation2DInfo( + inputShape: [number, number, number, number], + filterShape: [number, number, number], strides: number|[number, number], + pad: 'same'|'valid'|number, dataFormat: 'NHWC' = 'NHWC', + dilations: number|[number, number]) { + // `computerConv2DInfo` require filterShape to be in the dimension of: + // `[filterHeight, filterWidth, depth, outDepth]`, dilation2d doesn't have + // outDepth, it assumes input's depth. + // Input shape: [batch, height, width, inChannels] + const inputDepth = inputShape[3]; + const $filterShape = + [...filterShape, inputDepth] as [number, number, number, number]; + const $dataFormat = convertConv2DDataFormat(dataFormat); + + return computeConv2DInfo( + inputShape, $filterShape, strides, dilations, pad, + null /* roundingMode */, null /* depthWise */, $dataFormat); +} + export function computePool2DInfo( inShape: [number, number, number, number], filterSize: [number, number]|number, strides: number|[number, number], diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts index eb24eb36c21..26b6b924df1 100644 --- a/tfjs-core/src/ops/dilation2d.ts +++ b/tfjs-core/src/ops/dilation2d.ts @@ -27,8 +27,6 @@ import * as util from '../util'; import {reshape} from './array_ops'; import {op} from './operation'; - - /** * Computes the grayscale dilation over the input `x`. * @@ -61,9 +59,9 @@ import {op} from './operation'; /** @doc {heading: 'Operations', subheading: 'Basic math'} */ function dilation2d_( x: T|TensorLike, filter: Tensor3D|TensorLike, - strides: [number, number]|number, pad: 'valid'|'same'|number, - dataFormat: 'NHWC' = 'NHWC', - dilations: [number, number]|number = [1, 1]): T { + strides: [number, number]|number, pad: 'valid'|'same', + dilations: [number, number]|number = [1, 1], + dataFormat: 'NHWC' = 'NHWC'): T { const $x = convertToTensor(x, 'x', 'dilation2d'); const $filter = convertToTensor(filter, 'filter', 'dilation2d'); @@ -88,7 +86,7 @@ function dilation2d_( () => `Error in dilation2d: Only NHWC is currently supported, ` + `but got dataFormat of ${dataFormat}`); - const inputs: Dilation2DInputs = {x: $x, filter: $filter}; + const inputs: Dilation2DInputs = {x: x4D, filter: $filter}; const attrs: Dilation2DAttrs = {strides, pad, dataFormat, dilations}; const res = ENGINE.runKernel( diff --git a/tfjs-core/src/ops/dilation2d_test.ts b/tfjs-core/src/ops/dilation2d_test.ts index ddad7c1eb37..2445663d881 100644 --- a/tfjs-core/src/ops/dilation2d_test.ts +++ b/tfjs-core/src/ops/dilation2d_test.ts @@ -20,19 +20,283 @@ import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; import {expectArraysClose} from '../test_util'; describeWithFlags('dilation2d', ALL_ENVS, () => { - it('Should throw error.', async () => { - const inputDepth = 1; - const inputShape: [number, number, number] = [2, 2, inputDepth]; - const filterShape: [number, number, number] = [1, 1, inputDepth]; + it('valid padding.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); - const pad = 'same'; - const stride: [number, number] = [1, 1]; + const result = tf.dilation2d(x, filter, 1 /* strides */, 'valid'); - const x = tf.tensor3d([1, 1, 1, 1], inputShape); - const filter = tf.tensor3d([1], filterShape); + expect(result.shape).toEqual([1, 1, 1, 1]); + expectArraysClose(await result.data(), [.5]); + }); + + it('same padding.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([1, 2, 2, 1]); + expectArraysClose(await result.data(), [.5, .6, .7, .8]); + }); + + it('same padding depth 3.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 3]; + const filterShape: [number, number, number] = [2, 2, 3]; + const x = tf.tensor4d( + [.1, .2, .0, .2, .3, .1, .3, .4, .2, .4, .5, .3], inputShape); + const filter = tf.tensor3d( + [.4, .5, .3, .3, .4, .2, .1, .2, .0, .0, .1, -.1], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([1, 2, 2, 3]); + expectArraysClose( + await result.data(), [.5, .7, .3, .6, .8, .4, .7, .9, .5, .8, 1., .6]); + }); + + it('same padding batch 2.', async () => { + const inputShape: [number, number, number, number] = [2, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .2, .3, .4, .5], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([2, 2, 2, 1]); + expectArraysClose(await result.data(), [.5, .6, .7, .8, .6, .7, .8, .9]); + }); + + it('same padding filter 2.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([1, 3, 3, 1]); + expectArraysClose( + await result.data(), [.7, .8, .7, 1, 1.1, 1, 1.1, 1.2, 1.3]); + }); + + it('valid padding non-square window.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [1, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'valid'); + + expect(result.shape).toEqual([1, 2, 1, 1]); + expectArraysClose(await result.data(), [.5, .7]); + }); + + it('same padding dilations 2.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same', 2); + + // Because dilations = 2, the effective filter is [3, 3, 1]: + // filter_eff = [[[.4], [.0], [.3]], + // [[.0], [.0], [.0]], + // [[.1], [.0], [.2]]] + expect(result.shape).toEqual([1, 3, 3, 1]); + expectArraysClose( + await result.data(), [.7, .8, .6, 1., 1.1, .9, .8, .9, .9]); + }); + + it('valid padding uneven stride.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 4, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d( + [.1, .2, .3, .4, .5, .6, .7, .8, .9, 1., 1.1, 1.2], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + + const result = tf.dilation2d(x, filter, [1, 2] /* strides */, 'valid'); + + expect(result.shape).toEqual([1, 2, 2, 1]); + expectArraysClose(await result.data(), [.8, 1., 1.2, 1.4]); + }); + + it('throws when input rank is not 3 or 4.', async () => { + const filterShape: [number, number, number] = [1, 1, 1]; + // tslint:disable-next-line:no-any + const x: any = tf.tensor1d([.5]); + const filter = tf.tensor3d([.4], filterShape); + + expect(() => tf.dilation2d(x, filter, 1, 'valid')).toThrowError(); + }); + + it('thorws when filter is not rank 3.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number] = [2, 2]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + // tslint:disable-next-line:no-any + const filter: any = tf.tensor2d([.4, .3, .1, .0], filterShape); + + expect(() => tf.dilation2d(x, filter, 1, 'valid')).toThrowError(); + }); + + it('throws when data format is not NHWC.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); + // tslint:disable-next-line:no-any + const dataFormat: any = 'NCHW'; + + expect( + () => tf.dilation2d(x, filter, 1 /* strides */, 'valid', 1, dataFormat)) + .toThrowError(); + }); + + it('gradient valid padding.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [1, 1, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.5], filterShape); + const dy = tf.tensor4d([.2, .3, .4, .2, .1, 1., .2, .3, .4], inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'valid')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [3.1]); + }); + + it('gradient same padding.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [1, 1, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.5], filterShape); + const dy = tf.tensor4d([.2, .3, .4, .2, .1, 1., .2, .3, .4], inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'same')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [3.1]); + }); + + it('gradient same padding depth 2.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 3]; + const filterShape: [number, number, number] = [1, 1, 3]; + const x = tf.tensor4d( + [.1, .2, .0, .2, .3, .1, .3, .4, .2, .4, .5, .3], inputShape); + const filter = tf.tensor3d([.4, .5, .6], filterShape); + const dy = tf.tensor4d( + [.2, .3, .4, .2, .1, 1., .2, .3, .4, .8, -.1, .1], inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'same')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose( + await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4, .8, -.1, .1]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [1.4, .6, 1.9]); + }); + + it('gradient valid padding filter 2.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const dyShape: [number, number, number, number] = [1, 2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + const dy = tf.tensor4d([.2, .3, .4, .2], dyShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'valid')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [0, 0, 0, 0, .2, .3, 0, .4, .2]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [0, 0, 0, 1.1]); + }); + + xit('gradient same padding filter 2.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + // y is: [.7, .8, .7, 1, 1.1, 1, 1.1, 1.2, 1.3] + const dy = tf.tensor4d([.2, .3, .4, .2, .1, .5, 0, .8, .7], inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'same')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + // cpu output is [0, 0, 0, 0, .2, 1.2, 0, 1, .8] + expectArraysClose(await dx.data(), [0, 0, .4, 0, .2, .8, 0, 1, .8]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [2.4, 0, 0, .8]); + }); + + it('gradient same padding filter 2 depth 3.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 3]; + const filterShape: [number, number, number] = [2, 2, 3]; + const x = tf.tensor4d( + [ + .1, .2, .3, .4, .5, .6, .7, .8, .9, .3, .2, .3, .4, .5, + .1, .9, .6, .3, .4, .5, .6, .2, .3, .5, .1, .2, .3 + ], + inputShape); + const filter = tf.tensor3d( + [.4, .3, .1, .2, .2, .1, .7, .3, .8, .4, .9, .1], filterShape); + const dy = tf.tensor4d( + [ + .2, .3, .4, .2, .1, .5, 0, .8, .7, .1, .2, .1, .2, .3, + .4, .5, .6, .6, .6, .7, .8, .3, .2, .1, .2, .4, .2 + ], + inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'same')); + + const [dx, dfilter] = grads([x, filter], dy); - const result = tf.dilation2d(x, filter, stride, pad); + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [ + 0, 0, 0, 0, 0, 0, 0, .8, .5, .2, 0, .4, 0, .3, + 0, .9, .7, .7, .7, .7, .9, .3, .4, .5, .2, .7, .8 + ]); - expectArraysClose(await result.data(), [1, 1, 1, 1]); + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose( + await dfilter.data(), + [1.6, 2.7, 1.1, .2, 0, .5, .3, 0, 2.2, .2, .9, 0]); }); }); diff --git a/tfjs-core/src/public/chained_ops/dilation2d.ts b/tfjs-core/src/public/chained_ops/dilation2d.ts new file mode 100644 index 00000000000..39b2ec5d62e --- /dev/null +++ b/tfjs-core/src/public/chained_ops/dilation2d.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ +import {dilation2d} from '../../ops/dilation2d'; +import {Tensor, Tensor3D, Tensor4D} from '../../tensor'; +import {Rank, TensorLike3D} from '../../types'; + +declare module '../../tensor' { + interface Tensor { + dilation2d( + filter: Tensor3D|TensorLike3D, strides: [number, number]|number, + pad: 'valid'|'same', dilations?: [number, number]|number, + dataFormat?: 'NHWC'): T; + } +} + +Tensor.prototype.dilation2d = function( + filter: Tensor3D|TensorLike3D, strides: [number, number]|number, + pad: 'valid'|'same', dilations?: [number, number]|number, + dataFormat?: 'NHWC'): T { + this.throwIfDisposed(); + return dilation2d(this, filter, strides, pad, dilations, dataFormat) as T; +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 9cf695f7dbd..5c71468725f 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -28,6 +28,7 @@ import './cumsum'; import './depth_to_space'; import './depthwise_conv2d'; import './depthwise_conv2D_deprecated'; +import './dilation2d'; import './div'; import './div_no_nan'; import './dot'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 7a6cf931d77..a35227e811c 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -38,6 +38,7 @@ const CHAINED_OPS = [ 'depthToSpace', 'depthwiseConv2d', 'depthwiseConv2D', + 'dilation2d', 'div', 'divNoNan', 'dot', diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index f7e573682f9..563fa8d81d7 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -28,6 +28,7 @@ import {conv2DBackpropInputGradConfig} from './gradients/Conv2DBackpropInput_gra import {conv3DGradConfig} from './gradients/Conv3D_grad'; import {cumsumGradConfig} from './gradients/Cumsum_grad'; import {depthwiseConv2dNativeGradConfig} from './gradients/DepthwiseConv2dNative_grad'; +import {dilation2dGradConfig} from './gradients/Dilation2D_grad'; import {divGradConfig} from './gradients/Div_grad'; import {eluGradConfig} from './gradients/Elu_grad'; import {floorDivGradConfig} from './gradients/FloorDiv_grad'; @@ -78,6 +79,7 @@ const gradConfigs: GradConfig[] = [ conv3DGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, + dilation2dGradConfig, divGradConfig, eluGradConfig, floorDivGradConfig, diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index 438445dd452..a8bdbc8f819 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -609,7 +609,7 @@ export function toNestedArray(shape: number[], a: TypedArray) { return []; } if (size !== a.length) { - throw new Error(`[${shape}] does not match the input size.`); + throw new Error(`[${shape}] does not match the input size ${a.length}.`); } return createNestedArray(0, shape, a); @@ -643,6 +643,25 @@ export function makeZerosTypedArray( } } +/** + * Make nested `TypedArray` filled with zeros. + * @param shape The shape information for the nested array. + * @param dtype dtype of the array element. + */ +export function makeZerosNestedTypedArray( + shape: number[], dtype: D) { + const size = shape.reduce((prev, curr) => prev * curr, 1); + if (dtype == null || dtype === 'float32' || dtype === 'complex64') { + return toNestedArray(shape, new Float32Array(size)); + } else if (dtype === 'int32') { + return toNestedArray(shape, new Int32Array(size)); + } else if (dtype === 'bool') { + return toNestedArray(shape, new Uint8Array(size)); + } else { + throw new Error(`Unknown data type ${dtype}`); + } +} + /** * Returns the current high-resolution time in milliseconds relative to an * arbitrary time in the past. It works across different platforms (node.js, diff --git a/tfjs-node/src/kernels/Dilation2D.ts b/tfjs-node/src/kernels/Dilation2D.ts new file mode 100644 index 00000000000..24c064672e3 --- /dev/null +++ b/tfjs-node/src/kernels/Dilation2D.ts @@ -0,0 +1,54 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, KernelConfig, registerKernel} from '@tensorflow/tfjs'; + +import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +export const dilation2dConfig: KernelConfig = { + kernelName: Dilation2D, + backendName: 'tensorflow', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter} = inputs as Dilation2DInputs; + const {strides, pad, dataFormat, dilations} = + attrs as {} as Dilation2DAttrs; + const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, dataFormat, + dilations); + const $strides = [1, strideHeight, strideWidth, 1]; + const $dilations = [1, dilationHeight, dilationWidth, 1]; + + const nodeBackend = backend as NodeJSKernelBackend; + + const opAttrs = [ + createTensorsTypeOpAttr('T', x.dtype), + {name: 'strides', type: nodeBackend.binding.TF_ATTR_INT, value: $strides}, + {name: 'rates', type: nodeBackend.binding.TF_ATTR_INT, value: $dilations}, + { + name: 'padding', + type: nodeBackend.binding.TF_ATTR_STRING, + value: padInfo.type + } + ]; + + return nodeBackend.executeSingleOutput(Dilation2D, opAttrs, [x, filter]); + } +}; + +registerKernel(dilation2dConfig); diff --git a/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts new file mode 100644 index 00000000000..3419a7ff6ae --- /dev/null +++ b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropFilter, Dilation2DBackpropFilterInputs, KernelConfig, registerKernel} from '@tensorflow/tfjs'; + +import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +export const dilation2dBackpropFilterConfig: KernelConfig = { + kernelName: Dilation2DBackpropFilter, + backendName: 'tensorflow', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = inputs as Dilation2DBackpropFilterInputs; + const {strides, pad, dataFormat, dilations} = + attrs as {} as Dilation2DAttrs; + const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, dataFormat, + dilations); + const $strides = [1, strideHeight, strideWidth, 1]; + const $dilations = [1, dilationHeight, dilationWidth, 1]; + + const nodeBackend = backend as NodeJSKernelBackend; + + const opAttrs = [ + createTensorsTypeOpAttr('T', x.dtype), + {name: 'strides', type: nodeBackend.binding.TF_ATTR_INT, value: $strides}, + {name: 'rates', type: nodeBackend.binding.TF_ATTR_INT, value: $dilations}, + { + name: 'padding', + type: nodeBackend.binding.TF_ATTR_STRING, + value: padInfo.type + } + ]; + + return nodeBackend.executeSingleOutput( + Dilation2DBackpropFilter, opAttrs, [x, filter, dy]); + } +}; + +registerKernel(dilation2dBackpropFilterConfig); diff --git a/tfjs-node/src/kernels/Dilation2DBackpropInput.ts b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts new file mode 100644 index 00000000000..f9dc14956fd --- /dev/null +++ b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts @@ -0,0 +1,55 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropInput, Dilation2DBackpropInputInputs, KernelConfig, registerKernel} from '@tensorflow/tfjs'; + +import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +export const dilation2dBackpropInputConfig: KernelConfig = { + kernelName: Dilation2DBackpropInput, + backendName: 'tensorflow', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = inputs as Dilation2DBackpropInputInputs; + const {strides, pad, dataFormat, dilations} = + attrs as {} as Dilation2DAttrs; + const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, dataFormat, + dilations); + const $strides = [1, strideHeight, strideWidth, 1]; + const $dilations = [1, dilationHeight, dilationWidth, 1]; + + const nodeBackend = backend as NodeJSKernelBackend; + + const opAttrs = [ + createTensorsTypeOpAttr('T', x.dtype), + {name: 'strides', type: nodeBackend.binding.TF_ATTR_INT, value: $strides}, + {name: 'rates', type: nodeBackend.binding.TF_ATTR_INT, value: $dilations}, + { + name: 'padding', + type: nodeBackend.binding.TF_ATTR_STRING, + value: padInfo.type + } + ]; + + return nodeBackend.executeSingleOutput( + Dilation2DBackpropInput, opAttrs, [x, filter, dy]); + } +}; + +registerKernel(dilation2dBackpropInputConfig); From 26446446092681b83f6ee4c31aa85e444f6d7668 Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 16 Jun 2020 13:35:22 -0700 Subject: [PATCH 06/14] . --- tfjs-core/src/ops/dilation2d_test.ts | 22 ---------------------- 1 file changed, 22 deletions(-) diff --git a/tfjs-core/src/ops/dilation2d_test.ts b/tfjs-core/src/ops/dilation2d_test.ts index 2445663d881..56ad246ed51 100644 --- a/tfjs-core/src/ops/dilation2d_test.ts +++ b/tfjs-core/src/ops/dilation2d_test.ts @@ -242,28 +242,6 @@ describeWithFlags('dilation2d', ALL_ENVS, () => { expectArraysClose(await dfilter.data(), [0, 0, 0, 1.1]); }); - xit('gradient same padding filter 2.', async () => { - const inputShape: [number, number, number, number] = [1, 3, 3, 1]; - const filterShape: [number, number, number] = [2, 2, 1]; - const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); - const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); - // y is: [.7, .8, .7, 1, 1.1, 1, 1.1, 1.2, 1.3] - const dy = tf.tensor4d([.2, .3, .4, .2, .1, .5, 0, .8, .7], inputShape); - - const grads = tf.grads( - (x: tf.Tensor4D, filter: tf.Tensor3D) => - x.dilation2d(filter, 1, 'same')); - - const [dx, dfilter] = grads([x, filter], dy); - - expect(dx.shape).toEqual(x.shape); - // cpu output is [0, 0, 0, 0, .2, 1.2, 0, 1, .8] - expectArraysClose(await dx.data(), [0, 0, .4, 0, .2, .8, 0, 1, .8]); - - expect(dfilter.shape).toEqual(filterShape); - expectArraysClose(await dfilter.data(), [2.4, 0, 0, .8]); - }); - it('gradient same padding filter 2 depth 3.', async () => { const inputShape: [number, number, number, number] = [1, 3, 3, 3]; const filterShape: [number, number, number] = [2, 2, 3]; From f8ece9b8ecb51fb25241b10eac775fbf37c1d4a4 Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 16 Jun 2020 15:36:30 -0700 Subject: [PATCH 07/14] Exclude wasm. --- tfjs-backend-wasm/src/setup_test.ts | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 794acd95e12..0cbcc59242e 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -364,6 +364,11 @@ const TEST_FILTERS: TestFilter[] = [ // Complex numbers not supported yet. excludes: ['complex'], }, + { + startsWith: 'dilation2d', + // Not implemented yet. + excludes: ['dilation2d'] + } ]; const customInclude = (testName: string) => { From 6df888645b82538ef6310cf76715395c9029741a Mon Sep 17 00:00:00 2001 From: Na Li Date: Tue, 16 Jun 2020 17:38:26 -0700 Subject: [PATCH 08/14] Change test name to avoid being included in wasm test. --- tfjs-backend-wasm/src/setup_test.ts | 5 ----- tfjs-core/src/ops/dilation2d_test.ts | 2 +- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 0cbcc59242e..557331c046f 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -363,11 +363,6 @@ const TEST_FILTERS: TestFilter[] = [ startsWith: 'onesLike', // Complex numbers not supported yet. excludes: ['complex'], - }, - { - startsWith: 'dilation2d', - // Not implemented yet. - excludes: ['dilation2d'] } ]; diff --git a/tfjs-core/src/ops/dilation2d_test.ts b/tfjs-core/src/ops/dilation2d_test.ts index 56ad246ed51..24acdea21c2 100644 --- a/tfjs-core/src/ops/dilation2d_test.ts +++ b/tfjs-core/src/ops/dilation2d_test.ts @@ -84,7 +84,7 @@ describeWithFlags('dilation2d', ALL_ENVS, () => { await result.data(), [.7, .8, .7, 1, 1.1, 1, 1.1, 1.2, 1.3]); }); - it('valid padding non-square window.', async () => { + it('valid padding non-square-window.', async () => { const inputShape: [number, number, number, number] = [1, 2, 2, 1]; const filterShape: [number, number, number] = [1, 2, 1]; const x = tf.tensor4d([.1, .2, .3, .4], inputShape); From 7ae1856176dea35cfd3683a9e71652f9fc57e958 Mon Sep 17 00:00:00 2001 From: Na Li Date: Fri, 19 Jun 2020 17:48:48 -0700 Subject: [PATCH 09/14] Resolve merge conflict. --- tfjs-node/src/index.ts | 4 +++- tfjs-node/src/kernels/Dilation2D.ts | 4 +--- tfjs-node/src/kernels/Dilation2DBackpropFilter.ts | 4 +--- tfjs-node/src/kernels/Dilation2DBackpropInput.ts | 4 +--- tfjs-node/src/register_all_kernels.ts | 10 ++++++++-- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tfjs-node/src/index.ts b/tfjs-node/src/index.ts index 97e1ddc0111..24dc544b6cb 100644 --- a/tfjs-node/src/index.ts +++ b/tfjs-node/src/index.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -// Import all kernels. +// Register all kernels. import './register_all_kernels'; import * as tf from '@tensorflow/tfjs'; @@ -66,9 +66,11 @@ export * from './node'; // tslint:disable-next-line:no-require-imports const pjson = require('../package.json'); +// Side effects for default initialization of Node backend. tf.registerBackend('tensorflow', () => { return new NodeJSKernelBackend(bindings as TFJSBinding, pjson.name); }, 3 /* priority */); +import './register_all_kernels'; const success = tf.setBackend('tensorflow'); if (!success) { diff --git a/tfjs-node/src/kernels/Dilation2D.ts b/tfjs-node/src/kernels/Dilation2D.ts index 24c064672e3..113f23ecdbb 100644 --- a/tfjs-node/src/kernels/Dilation2D.ts +++ b/tfjs-node/src/kernels/Dilation2D.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, KernelConfig, registerKernel} from '@tensorflow/tfjs'; +import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, KernelConfig} from '@tensorflow/tfjs'; import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; @@ -50,5 +50,3 @@ export const dilation2dConfig: KernelConfig = { return nodeBackend.executeSingleOutput(Dilation2D, opAttrs, [x, filter]); } }; - -registerKernel(dilation2dConfig); diff --git a/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts index 3419a7ff6ae..71ea6492cb5 100644 --- a/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts +++ b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, Dilation2DAttrs, Dilation2DBackpropFilter, Dilation2DBackpropFilterInputs, KernelConfig, registerKernel} from '@tensorflow/tfjs'; +import {backend_util, Dilation2DAttrs, Dilation2DBackpropFilter, Dilation2DBackpropFilterInputs, KernelConfig} from '@tensorflow/tfjs'; import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; @@ -51,5 +51,3 @@ export const dilation2dBackpropFilterConfig: KernelConfig = { Dilation2DBackpropFilter, opAttrs, [x, filter, dy]); } }; - -registerKernel(dilation2dBackpropFilterConfig); diff --git a/tfjs-node/src/kernels/Dilation2DBackpropInput.ts b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts index f9dc14956fd..b7ce56dfb1a 100644 --- a/tfjs-node/src/kernels/Dilation2DBackpropInput.ts +++ b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {backend_util, Dilation2DAttrs, Dilation2DBackpropInput, Dilation2DBackpropInputInputs, KernelConfig, registerKernel} from '@tensorflow/tfjs'; +import {backend_util, Dilation2DAttrs, Dilation2DBackpropInput, Dilation2DBackpropInputInputs, KernelConfig} from '@tensorflow/tfjs'; import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; @@ -51,5 +51,3 @@ export const dilation2dBackpropInputConfig: KernelConfig = { Dilation2DBackpropInput, opAttrs, [x, filter, dy]); } }; - -registerKernel(dilation2dBackpropInputConfig); diff --git a/tfjs-node/src/register_all_kernels.ts b/tfjs-node/src/register_all_kernels.ts index 56ea721afa0..a9d5d0fc16b 100644 --- a/tfjs-node/src/register_all_kernels.ts +++ b/tfjs-node/src/register_all_kernels.ts @@ -19,13 +19,19 @@ // the contents of this file and import only the kernels that are needed. import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {dilation2dConfig} from './kernels/Dilation2D'; +import {dilation2dBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter'; +import {dilation2dBackpropInputConfig} from './kernels/Dilation2DBackpropInput'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {softmaxConfig} from './kernels/Softmax'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here -const kernelConfigs: KernelConfig[] = - [nonMaxSuppressionV5Config, softmaxConfig, squaredDifferenceConfig]; +const kernelConfigs: KernelConfig[] = [ + dilation2dConfig, dilation2dBackpropInputConfig, + dilation2dBackpropFilterConfig, nonMaxSuppressionV5Config, softmaxConfig, + squaredDifferenceConfig +]; for (const kernelConfig of kernelConfigs) { registerKernel(kernelConfig); From 490c0d7146ee813c17c11d6b7ea54a4258467d44 Mon Sep 17 00:00:00 2001 From: Na Li Date: Wed, 17 Jun 2020 13:52:01 -0700 Subject: [PATCH 10/14] Add more annotation. --- tfjs-backend-cpu/src/kernels/Dilation2D.ts | 4 ++++ tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts | 5 +++++ tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts | 5 +++++ .../src/operations/executors/convolution_executor_test.ts | 1 - 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tfjs-backend-cpu/src/kernels/Dilation2D.ts b/tfjs-backend-cpu/src/kernels/Dilation2D.ts index 747e03edab7..b90beaa3839 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2D.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2D.ts @@ -63,6 +63,10 @@ export const dilation2dConfig: KernelConfig = { const output = util.makeZerosNestedTypedArray(outShape, x.dtype) as number[][][][]; + // Upsampling the input by fill in `dilation size - 1` values between each + // input value. + // This implementation follows the TF c++ implementation: + // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc for (let b = 0; b < batchSize; ++b) { for (let hOut = 0; hOut < outHeight; ++hOut) { const hBeg = hOut * strideHeight - padInfo.top; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts index d08d3555b96..8a07312ae98 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts @@ -77,6 +77,11 @@ export const dilation2dBackpropFilterConfig: KernelConfig = { const gradients = util.makeZerosNestedTypedArray( filter.shape, filter.dtype) as number[][][]; + // In the case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + // This implementation follows the TF c++ implementation: + // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc for (let b = 0; b < batchSize; ++b) { for (let hOut = 0; hOut < outHeight; ++hOut) { const hBeg = hOut * strideHeight - padInfo.top; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts index f4c31a199e7..2467425f923 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts @@ -77,6 +77,11 @@ export const dilation2dBackpropInputConfig: KernelConfig = { const gradients = util.makeZerosNestedTypedArray(x.shape, x.dtype) as number[][][][]; + // In the case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + // This implementation follows the TF c++ implementation: + // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc for (let b = 0; b < batchSize; ++b) { for (let hOut = 0; hOut < outHeight; ++hOut) { const hBeg = hOut * strideHeight - padInfo.top; diff --git a/tfjs-converter/src/operations/executors/convolution_executor_test.ts b/tfjs-converter/src/operations/executors/convolution_executor_test.ts index 123436dc797..1fbcb9609b3 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor_test.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor_test.ts @@ -574,7 +574,6 @@ describe('convolution', () => { node.inputParams['filter'] = createTensorAttr(1); node.attrParams['strides'] = createNumericArrayAttr([1, 1, 1, 1]); node.attrParams['pad'] = createStrAttr('same'); - node.attrParams['dataFormat'] = createStrAttr('NHWC'); node.attrParams['dilations'] = createNumericArrayAttr([1, 2, 2, 1]); const input1 = [tfc.scalar(1.0)]; From c9a6030938cf2921ff61eaa30c57916a245cf722 Mon Sep 17 00:00:00 2001 From: Na Li Date: Wed, 17 Jun 2020 14:30:20 -0700 Subject: [PATCH 11/14] Add more annotation. --- tfjs-core/src/ops/conv_util.ts | 4 ++-- tfjs-core/src/util.ts | 2 +- tfjs-node/src/index.ts | 2 ++ 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tfjs-core/src/ops/conv_util.ts b/tfjs-core/src/ops/conv_util.ts index 845b246bff2..48b34896dea 100644 --- a/tfjs-core/src/ops/conv_util.ts +++ b/tfjs-core/src/ops/conv_util.ts @@ -106,9 +106,9 @@ export function computeDilation2DInfo( // `[filterHeight, filterWidth, depth, outDepth]`, dilation2d doesn't have // outDepth, it assumes input's depth. // Input shape: [batch, height, width, inChannels] - const inputDepth = inputShape[3]; + const inputChannels = inputShape[3]; const $filterShape = - [...filterShape, inputDepth] as [number, number, number, number]; + [...filterShape, inputChannels] as [number, number, number, number]; const $dataFormat = convertConv2DDataFormat(dataFormat); return computeConv2DInfo( diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index a8bdbc8f819..929207bf752 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -651,7 +651,7 @@ export function makeZerosTypedArray( export function makeZerosNestedTypedArray( shape: number[], dtype: D) { const size = shape.reduce((prev, curr) => prev * curr, 1); - if (dtype == null || dtype === 'float32' || dtype === 'complex64') { + if (dtype == null || dtype === 'float32') { return toNestedArray(shape, new Float32Array(size)); } else if (dtype === 'int32') { return toNestedArray(shape, new Int32Array(size)); diff --git a/tfjs-node/src/index.ts b/tfjs-node/src/index.ts index 24dc544b6cb..a79fdae0747 100644 --- a/tfjs-node/src/index.ts +++ b/tfjs-node/src/index.ts @@ -70,6 +70,8 @@ const pjson = require('../package.json'); tf.registerBackend('tensorflow', () => { return new NodeJSKernelBackend(bindings as TFJSBinding, pjson.name); }, 3 /* priority */); + +// Register kernels. import './register_all_kernels'; const success = tf.setBackend('tensorflow'); From f589462fcab229094518389811b5ba7682623c60 Mon Sep 17 00:00:00 2001 From: Na Li Date: Thu, 18 Jun 2020 15:15:23 -0700 Subject: [PATCH 12/14] Add doccomment. --- tfjs-backend-cpu/src/kernels/Dilation2D.ts | 3 +-- tfjs-core/src/ops/conv_util.ts | 2 +- tfjs-core/src/ops/dilation2d.ts | 22 +++++++++++----------- 3 files changed, 13 insertions(+), 14 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Dilation2D.ts b/tfjs-backend-cpu/src/kernels/Dilation2D.ts index b90beaa3839..341a8732d45 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2D.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2D.ts @@ -15,8 +15,7 @@ * ============================================================================= */ -import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, TypedArray, util} from '@tensorflow/tfjs-core'; -import {KernelConfig} from '@tensorflow/tfjs-core'; +import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; import {MathBackendCPU} from '../backend_cpu'; diff --git a/tfjs-core/src/ops/conv_util.ts b/tfjs-core/src/ops/conv_util.ts index 48b34896dea..8650c6dd5d0 100644 --- a/tfjs-core/src/ops/conv_util.ts +++ b/tfjs-core/src/ops/conv_util.ts @@ -104,7 +104,7 @@ export function computeDilation2DInfo( dilations: number|[number, number]) { // `computerConv2DInfo` require filterShape to be in the dimension of: // `[filterHeight, filterWidth, depth, outDepth]`, dilation2d doesn't have - // outDepth, it assumes input's depth. + // outDepth, it should have the same depth as the input. // Input shape: [batch, height, width, inChannels] const inputChannels = inputShape[3]; const $filterShape = diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts index 26b6b924df1..622df0bfb9c 100644 --- a/tfjs-core/src/ops/dilation2d.ts +++ b/tfjs-core/src/ops/dilation2d.ts @@ -65,18 +65,10 @@ function dilation2d_( const $x = convertToTensor(x, 'x', 'dilation2d'); const $filter = convertToTensor(filter, 'filter', 'dilation2d'); - let x4D = $x as Tensor4D; - let reshapedTo4D = false; - - if ($x.rank === 3) { - x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]); - reshapedTo4D = true; - } - util.assert( - x4D.rank === 4, - () => `Error in dilation2d: input must be rank 4, but got rank ` + - `${x4D.rank}.`); + $x.rank === 3 || $x.rank === 4, + () => `Error in dilation2d: input must be rank 3 or 4, but got rank ` + + `${$x.rank}.`); util.assert( $filter.rank === 3, () => `Error in dilation2d: filter must be rank 3, but got rank ` + @@ -86,6 +78,14 @@ function dilation2d_( () => `Error in dilation2d: Only NHWC is currently supported, ` + `but got dataFormat of ${dataFormat}`); + let x4D = $x as Tensor4D; + let reshapedTo4D = false; + + if ($x.rank === 3) { + x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]); + reshapedTo4D = true; + } + const inputs: Dilation2DInputs = {x: x4D, filter: $filter}; const attrs: Dilation2DAttrs = {strides, pad, dataFormat, dilations}; From 46d39291ef68bd71f9a6bcf7db918ae178c14efe Mon Sep 17 00:00:00 2001 From: Na Li Date: Fri, 19 Jun 2020 17:52:53 -0700 Subject: [PATCH 13/14] . --- tfjs-node/src/index.ts | 3 --- 1 file changed, 3 deletions(-) diff --git a/tfjs-node/src/index.ts b/tfjs-node/src/index.ts index a79fdae0747..d098afcdb2a 100644 --- a/tfjs-node/src/index.ts +++ b/tfjs-node/src/index.ts @@ -71,9 +71,6 @@ tf.registerBackend('tensorflow', () => { return new NodeJSKernelBackend(bindings as TFJSBinding, pjson.name); }, 3 /* priority */); -// Register kernels. -import './register_all_kernels'; - const success = tf.setBackend('tensorflow'); if (!success) { throw new Error(`Could not initialize TensorFlow backend.`); From 73a0bfa626fb41d9d90d4faacf42846d08ed3ef6 Mon Sep 17 00:00:00 2001 From: Na Li Date: Mon, 22 Jun 2020 10:59:18 -0700 Subject: [PATCH 14/14] Remove dataFormat. --- tfjs-backend-cpu/src/kernels/Dilation2D.ts | 7 +++---- tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts | 7 +++---- tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts | 7 +++---- tfjs-core/src/kernel_names.ts | 1 - tfjs-core/src/ops/dilation2d.ts | 2 +- tfjs-node/src/kernels/Dilation2D.ts | 7 +++---- tfjs-node/src/kernels/Dilation2DBackpropFilter.ts | 7 +++---- tfjs-node/src/kernels/Dilation2DBackpropInput.ts | 7 +++---- 8 files changed, 19 insertions(+), 26 deletions(-) diff --git a/tfjs-backend-cpu/src/kernels/Dilation2D.ts b/tfjs-backend-cpu/src/kernels/Dilation2D.ts index 341a8732d45..f5d37064c20 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2D.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2D.ts @@ -24,8 +24,7 @@ export const dilation2dConfig: KernelConfig = { backendName: 'cpu', kernelFunc: ({inputs, backend, attrs}) => { const {x, filter} = inputs as Dilation2DInputs; - const {strides, pad, dataFormat, dilations} = - attrs as {} as Dilation2DAttrs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; const cpuBackend = backend as MathBackendCPU; const $x = @@ -56,8 +55,8 @@ export const dilation2dConfig: KernelConfig = { } = backend_util.computeDilation2DInfo( x.shape as [number, number, number, number], - filter.shape as [number, number, number], strides, pad, dataFormat, - dilations); + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); const output = util.makeZerosNestedTypedArray(outShape, x.dtype) as number[][][][]; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts index 8a07312ae98..2d39c057b77 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts @@ -26,8 +26,7 @@ export const dilation2dBackpropFilterConfig: KernelConfig = { kernelFunc: ({inputs, backend, attrs}) => { const {x, filter, dy} = inputs as {x: Tensor4D, filter: Tensor3D, dy: Tensor4D}; - const {strides, pad, dataFormat, dilations} = - attrs as {} as Dilation2DAttrs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; const cpuBackend = backend as MathBackendCPU; const $x = @@ -58,8 +57,8 @@ export const dilation2dBackpropFilterConfig: KernelConfig = { } = backend_util.computeDilation2DInfo( x.shape as [number, number, number, number], - filter.shape as [number, number, number], strides, pad, dataFormat, - dilations); + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); util.assert( dy.rank === outShape.length, diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts index 2467425f923..3e1405255b9 100644 --- a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts @@ -26,8 +26,7 @@ export const dilation2dBackpropInputConfig: KernelConfig = { kernelFunc: ({inputs, backend, attrs}) => { const {x, filter, dy} = inputs as {x: Tensor4D, filter: Tensor3D, dy: Tensor4D}; - const {strides, pad, dataFormat, dilations} = - attrs as {} as Dilation2DAttrs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; const cpuBackend = backend as MathBackendCPU; const $x = @@ -58,8 +57,8 @@ export const dilation2dBackpropInputConfig: KernelConfig = { } = backend_util.computeDilation2DInfo( x.shape as [number, number, number, number], - filter.shape as [number, number, number], strides, pad, dataFormat, - dilations); + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); util.assert( dy.rank === outShape.length, diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index b1a6c5c6d66..348db2586aa 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -197,7 +197,6 @@ export type Dilation2DInputs = Pick; export interface Dilation2DAttrs { strides: [number, number]|number; pad: 'valid'|'same'|number; - dataFormat: 'NHWC'; dilations: [number, number]|number; } diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts index 622df0bfb9c..794962d6371 100644 --- a/tfjs-core/src/ops/dilation2d.ts +++ b/tfjs-core/src/ops/dilation2d.ts @@ -87,7 +87,7 @@ function dilation2d_( } const inputs: Dilation2DInputs = {x: x4D, filter: $filter}; - const attrs: Dilation2DAttrs = {strides, pad, dataFormat, dilations}; + const attrs: Dilation2DAttrs = {strides, pad, dilations}; const res = ENGINE.runKernel( Dilation2D, inputs as {} as NamedTensorMap, diff --git a/tfjs-node/src/kernels/Dilation2D.ts b/tfjs-node/src/kernels/Dilation2D.ts index 113f23ecdbb..609fb7b9ceb 100644 --- a/tfjs-node/src/kernels/Dilation2D.ts +++ b/tfjs-node/src/kernels/Dilation2D.ts @@ -24,13 +24,12 @@ export const dilation2dConfig: KernelConfig = { backendName: 'tensorflow', kernelFunc: ({inputs, backend, attrs}) => { const {x, filter} = inputs as Dilation2DInputs; - const {strides, pad, dataFormat, dilations} = - attrs as {} as Dilation2DAttrs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = backend_util.computeDilation2DInfo( x.shape as [number, number, number, number], - filter.shape as [number, number, number], strides, pad, dataFormat, - dilations); + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); const $strides = [1, strideHeight, strideWidth, 1]; const $dilations = [1, dilationHeight, dilationWidth, 1]; diff --git a/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts index 71ea6492cb5..7896e03cc90 100644 --- a/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts +++ b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts @@ -24,13 +24,12 @@ export const dilation2dBackpropFilterConfig: KernelConfig = { backendName: 'tensorflow', kernelFunc: ({inputs, backend, attrs}) => { const {x, filter, dy} = inputs as Dilation2DBackpropFilterInputs; - const {strides, pad, dataFormat, dilations} = - attrs as {} as Dilation2DAttrs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = backend_util.computeDilation2DInfo( x.shape as [number, number, number, number], - filter.shape as [number, number, number], strides, pad, dataFormat, - dilations); + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); const $strides = [1, strideHeight, strideWidth, 1]; const $dilations = [1, dilationHeight, dilationWidth, 1]; diff --git a/tfjs-node/src/kernels/Dilation2DBackpropInput.ts b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts index b7ce56dfb1a..61c931dd6d5 100644 --- a/tfjs-node/src/kernels/Dilation2DBackpropInput.ts +++ b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts @@ -24,13 +24,12 @@ export const dilation2dBackpropInputConfig: KernelConfig = { backendName: 'tensorflow', kernelFunc: ({inputs, backend, attrs}) => { const {x, filter, dy} = inputs as Dilation2DBackpropInputInputs; - const {strides, pad, dataFormat, dilations} = - attrs as {} as Dilation2DAttrs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = backend_util.computeDilation2DInfo( x.shape as [number, number, number, number], - filter.shape as [number, number, number], strides, pad, dataFormat, - dilations); + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); const $strides = [1, strideHeight, strideWidth, 1]; const $dilations = [1, dilationHeight, dilationWidth, 1];