diff --git a/tfjs-backend-cpu/run_tests.ts b/tfjs-backend-cpu/run_tests.ts index 822f8f514b1..7152c2e5874 100644 --- a/tfjs-backend-cpu/run_tests.ts +++ b/tfjs-backend-cpu/run_tests.ts @@ -38,14 +38,13 @@ const cpuTests = 'src/**/*_test.ts'; const runner = new jasmineCtor(); runner.loadConfig({spec_files: [cpuTests, coreTests], random: false}); -// customInclude takes higher priority then TEST_FILTERS, only when -// customInclude return false will TEST_FILTERS be considered. const TEST_FILTERS: TestFilter[] = []; const customInclude = (testName: string) => { // Exclude webworker test if (testName.includes('computation in worker')) { return false; } + // Include all other tests. return true; }; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2D.ts b/tfjs-backend-cpu/src/kernels/Dilation2D.ts new file mode 100644 index 00000000000..f5d37064c20 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Dilation2D.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, KernelConfig, TypedArray, util} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export const dilation2dConfig: KernelConfig = { + kernelName: Dilation2D, + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter} = inputs as Dilation2DInputs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; + const cpuBackend = backend as MathBackendCPU; + + const $x = + util.toNestedArray( + x.shape, cpuBackend.data.get(x.dataId).values as TypedArray) as + number[][][][]; + + const $filter = util.toNestedArray( + filter.shape, + cpuBackend.data.get(filter.dataId).values as + TypedArray) as number[][][]; + + const { + batchSize, + inHeight, + inWidth, + inChannels, + outHeight, + outWidth, + padInfo, + strideHeight, + strideWidth, + filterHeight, + filterWidth, + dilationHeight, + dilationWidth, + outShape + } = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); + + const output = + util.makeZerosNestedTypedArray(outShape, x.dtype) as number[][][][]; + + // Each output value is the max of `x + filter` over the filter window, + // where the input is sampled every `dilation` steps; padding is skipped. 
+ // This implementation follows the TF c++ implementation: + // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc + for (let b = 0; b < batchSize; ++b) { + for (let hOut = 0; hOut < outHeight; ++hOut) { + const hBeg = hOut * strideHeight - padInfo.top; + for (let wOut = 0; wOut < outWidth; ++wOut) { + const wBeg = wOut * strideWidth - padInfo.left; + for (let d = 0; d < inChannels; ++d) { + let curVal = Number.NEGATIVE_INFINITY; + for (let h = 0; h < filterHeight; ++h) { + const hIn = hBeg + h * dilationHeight; + if (hIn >= 0 && hIn < inHeight) { + for (let w = 0; w < filterWidth; ++w) { + const wIn = wBeg + w * dilationWidth; + if (wIn >= 0 && wIn < inWidth) { + const val = $x[b][hIn][wIn][d] + $filter[h][w][d]; + if (val > curVal) { + curVal = val; + } + } + } + } + } + output[b][hOut][wOut][d] = curVal; + } + } + } + } + + const dataId = cpuBackend.write( + util.toTypedArray(output, x.dtype, false /* debug mode */), outShape, + x.dtype); + + return {dataId, shape: outShape, dtype: x.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts new file mode 100644 index 00000000000..2d39c057b77 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropFilter.ts @@ -0,0 +1,121 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropFilter, Tensor3D, Tensor4D, TypedArray, util} from '@tensorflow/tfjs-core'; +import {KernelConfig} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export const dilation2dBackpropFilterConfig: KernelConfig = { + kernelName: Dilation2DBackpropFilter, + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = + inputs as {x: Tensor4D, filter: Tensor3D, dy: Tensor4D}; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; + const cpuBackend = backend as MathBackendCPU; + + const $x = + util.toNestedArray( + x.shape, cpuBackend.data.get(x.dataId).values as TypedArray) as + number[][][][]; + + const $filter = util.toNestedArray( + filter.shape, + cpuBackend.data.get(filter.dataId).values as + TypedArray) as number[][][]; + + const { + batchSize, + inHeight, + inWidth, + inChannels, + outHeight, + outWidth, + padInfo, + strideHeight, + strideWidth, + filterHeight, + filterWidth, + dilationHeight, + dilationWidth, + outShape + } = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); + + util.assert( + dy.rank === outShape.length, + () => `Error in ${Dilation2DBackpropFilter}, dy ` + + `must have the same rank as output ${outShape.length}, but got ` + + `${dy.rank}`); + + const $dy = + util.toNestedArray( + outShape, cpuBackend.data.get(dy.dataId).values as TypedArray) as + number[][][][]; + + // The computed filter gradients has the same dimensions as the filter: + // [filterHeight, filterWidth, depth] + const gradients = util.makeZerosNestedTypedArray( + filter.shape, filter.dtype) as number[][][]; + + // In the 
case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + // This implementation follows the TF c++ implementation: + // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc + for (let b = 0; b < batchSize; ++b) { + for (let hOut = 0; hOut < outHeight; ++hOut) { + const hBeg = hOut * strideHeight - padInfo.top; + for (let wOut = 0; wOut < outWidth; ++wOut) { + const wBeg = wOut * strideWidth - padInfo.left; + for (let d = 0; d < inChannels; ++d) { + let curVal = Number.MIN_SAFE_INTEGER; + let hMax = 0; + let wMax = 0; + for (let h = 0; h < filterHeight; ++h) { + const hIn = hBeg + h * dilationHeight; + if (hIn >= 0 && hIn < inHeight) { + for (let w = 0; w < filterWidth; ++w) { + const wIn = wBeg + w * dilationWidth; + if (wIn >= 0 && wIn < inWidth) { + const val = $x[b][hIn][wIn][d] + $filter[h][w][d]; + if (val > curVal) { + curVal = val; + hMax = h; + wMax = w; + } + } + } + } + } + gradients[hMax][wMax][d] += $dy[b][hOut][wOut][d]; + } + } + } + } + + const dataId = cpuBackend.write( + util.toTypedArray(gradients, x.dtype, false /* debug mode */), + filter.shape, filter.dtype); + + return {dataId, shape: filter.shape, dtype: filter.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts new file mode 100644 index 00000000000..3e1405255b9 --- /dev/null +++ b/tfjs-backend-cpu/src/kernels/Dilation2DBackpropInput.ts @@ -0,0 +1,121 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropInput, Tensor3D, Tensor4D, TypedArray, util} from '@tensorflow/tfjs-core'; +import {KernelConfig} from '@tensorflow/tfjs-core'; + +import {MathBackendCPU} from '../backend_cpu'; + +export const dilation2dBackpropInputConfig: KernelConfig = { + kernelName: Dilation2DBackpropInput, + backendName: 'cpu', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = + inputs as {x: Tensor4D, filter: Tensor3D, dy: Tensor4D}; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; + const cpuBackend = backend as MathBackendCPU; + + const $x = + util.toNestedArray( + x.shape, cpuBackend.data.get(x.dataId).values as TypedArray) as + number[][][][]; + + const $filter = util.toNestedArray( + filter.shape, + cpuBackend.data.get(filter.dataId).values as + TypedArray) as number[][][]; + + const { + batchSize, + inHeight, + inWidth, + inChannels, + outHeight, + outWidth, + padInfo, + strideHeight, + strideWidth, + filterHeight, + filterWidth, + dilationHeight, + dilationWidth, + outShape + } = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); + + util.assert( + dy.rank === outShape.length, + () => `Error in ${Dilation2DBackpropInput}, dy ` + + `must have the same rank as output ${outShape.length}, but got ` + + `${dy.rank}`); + + const $dy = + util.toNestedArray( + 
outShape, cpuBackend.data.get(dy.dataId).values as TypedArray) as + number[][][][]; + + // The computed gradients has the same dimensions as the input: + // [batch, inputHeight, inputCols, inChannel] + const gradients = + util.makeZerosNestedTypedArray(x.shape, x.dtype) as number[][][][]; + + // In the case of multiple argmax branches, we only back-propagate along the + // last branch, i.e., the one with largest value of `h * filter_cols + w`, + // similarly to the max-pooling backward routines. + // This implementation follows the TF c++ implementation: + // https://github.com/tensorflow/tensorflow/blob/d9a3a849edc198e90172bc58eb293de457f9d986/tensorflow/core/kernels/dilation_ops.cc + for (let b = 0; b < batchSize; ++b) { + for (let hOut = 0; hOut < outHeight; ++hOut) { + const hBeg = hOut * strideHeight - padInfo.top; + for (let wOut = 0; wOut < outWidth; ++wOut) { + const wBeg = wOut * strideWidth - padInfo.left; + for (let d = 0; d < inChannels; ++d) { + let curVal = Number.MIN_SAFE_INTEGER; + let hInMax = (hBeg < 0) ? 0 : hBeg; + let wInMax = (wBeg < 0) ? 
0 : wBeg; + for (let h = 0; h < filterHeight; ++h) { + const hIn = hBeg + h * dilationHeight; + if (hIn >= 0 && hIn < inHeight) { + for (let w = 0; w < filterWidth; ++w) { + const wIn = wBeg + w * dilationWidth; + if (wIn >= 0 && wIn < inWidth) { + const val = $x[b][hIn][wIn][d] + $filter[h][w][d]; + if (val > curVal) { + curVal = val; + hInMax = hIn; + wInMax = wIn; + } + } + } + } + } + gradients[b][hInMax][wInMax][d] += $dy[b][hOut][wOut][d]; + } + } + } + } + + const dataId = cpuBackend.write( + util.toTypedArray(gradients, x.dtype, false /* debug mode */), x.shape, + x.dtype); + + return {dataId, shape: x.shape, dtype: x.dtype}; + } +}; diff --git a/tfjs-backend-cpu/src/register_all_kernels.ts b/tfjs-backend-cpu/src/register_all_kernels.ts index a1a6a66af82..bb0c38ea531 100644 --- a/tfjs-backend-cpu/src/register_all_kernels.ts +++ b/tfjs-backend-cpu/src/register_all_kernels.ts @@ -19,6 +19,9 @@ // the contents of this file and import only the kernels that are needed. import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {dilation2dConfig} from './kernels/Dilation2D'; +import {dilation2dBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter'; +import {dilation2dBackpropInputConfig} from './kernels/Dilation2DBackpropInput'; import {divConfig} from './kernels/Div'; import {maxConfig} from './kernels/Max'; import {maxPoolWithArgmaxConfig} from './kernels/MaxPoolWithArgmax'; @@ -29,8 +32,10 @@ import {transposeConfig} from './kernels/Transpose'; // List all kernel configs here const kernelConfigs: KernelConfig[] = [ - nonMaxSuppressionV5Config, squareConfig, squaredDifferenceConfig, divConfig, - transposeConfig, maxPoolWithArgmaxConfig, maxConfig + dilation2dConfig, dilation2dBackpropInputConfig, + dilation2dBackpropFilterConfig, nonMaxSuppressionV5Config, squareConfig, + squaredDifferenceConfig, divConfig, transposeConfig, maxPoolWithArgmaxConfig, + maxConfig ]; for (const kernelConfig of kernelConfigs) { diff --git 
a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 794acd95e12..557331c046f 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -363,7 +363,7 @@ const TEST_FILTERS: TestFilter[] = [ startsWith: 'onesLike', // Complex numbers not supported yet. excludes: ['complex'], - }, + } ]; const customInclude = (testName: string) => { diff --git a/tfjs-backend-webgl/src/setup_test.ts b/tfjs-backend-webgl/src/setup_test.ts index b0bd8043453..ff7f0531290 100644 --- a/tfjs-backend-webgl/src/setup_test.ts +++ b/tfjs-backend-webgl/src/setup_test.ts @@ -26,6 +26,11 @@ const customInclude = (testName: string) => { if (testName.includes(subStr)) { return false; } + + // Not implemented yet. + if (testName.includes('dilation2d')) { + return false; + } } return true; }; diff --git a/tfjs-converter/docs/supported_ops.md b/tfjs-converter/docs/supported_ops.md index cb744bcec0b..b7692994afe 100644 --- a/tfjs-converter/docs/supported_ops.md +++ b/tfjs-converter/docs/supported_ops.md @@ -112,6 +112,7 @@ |Conv2DBackpropInput|conv2dTranspose| |DepthwiseConv2d|depthwiseConv2d| |DepthwiseConv2dNative|depthwiseConv2d| +|Dilation2D|dilation2d| |MaxPool|maxPool| |MaxPool3D|maxPool3d| |Not mapped|pool| diff --git a/tfjs-converter/python/tensorflowjs/op_list/convolution.json b/tfjs-converter/python/tensorflowjs/op_list/convolution.json index e6bfe28c238..31240342989 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/convolution.json +++ b/tfjs-converter/python/tensorflowjs/op_list/convolution.json @@ -627,5 +627,38 @@ "type": "number[]" } ] + }, + { + "tfOpName": "Dilation2D", + "category": "convolution", + "inputs": [ + { + "start": 0, + "name": "x", + "type": "tensor" + }, + { + "start": 1, + "name": "filter", + "type": "tensor" + } + ], + "attrs": [ + { + "tfName": "strides", + "name": "strides", + "type": "number[]" + }, + { + "tfName": "rates", + "name": "dilations", + "type": "number[]" + }, + { + "tfName": "padding", 
+ "name": "pad", + "type": "string" + } + ] } ] \ No newline at end of file diff --git a/tfjs-converter/src/operations/executors/convolution_executor.ts b/tfjs-converter/src/operations/executors/convolution_executor.ts index 0856d63d628..c59620a7c5c 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor.ts @@ -231,6 +231,29 @@ export const executeOp: InternalOpExecutor = (node: Node, [stride[1], stride[2], stride[3]], pad as 'valid' | 'same')]; } + case 'Dilation2D': { + const strides = + getParamValue('strides', node, tensorMap, context) as number[]; + const pad = getParamValue('pad', node, tensorMap, context); + const dilations = + getParamValue('dilations', node, tensorMap, context) as number[]; + + // strides: [1, stride_height, stride_width, 1]. + const strideHeight = strides[1]; + const strideWidth = strides[2]; + + // dilations: [1, dilation_height, dilation_width, 1]. + const dilationHeight = dilations[1]; + const dilationWidth = dilations[2]; + + return [tfc.dilation2d( + getParamValue('x', node, tensorMap, context) as tfc.Tensor3D | + tfc.Tensor4D, + getParamValue('filter', node, tensorMap, context) as tfc.Tensor3D, + [strideHeight, strideWidth], pad as 'valid' | 'same', + [dilationHeight, dilationWidth], 'NHWC' /* dataFormat */)]; + } + default: throw TypeError(`Node type ${node.op} is not implemented`); } diff --git a/tfjs-converter/src/operations/executors/convolution_executor_test.ts b/tfjs-converter/src/operations/executors/convolution_executor_test.ts index 1ecec05769e..1fbcb9609b3 100644 --- a/tfjs-converter/src/operations/executors/convolution_executor_test.ts +++ b/tfjs-converter/src/operations/executors/convolution_executor_test.ts @@ -566,4 +566,25 @@ describe('convolution', () => { }); }); }); + + describe('dilation2d', () => { + it('should call tfc.dilation2d', () => { + spyOn(tfc, 'dilation2d'); + node.op = 'Dilation2D'; + node.inputParams['filter'] = 
createTensorAttr(1); + node.attrParams['strides'] = createNumericArrayAttr([1, 1, 1, 1]); + node.attrParams['pad'] = createStrAttr('same'); + node.attrParams['dilations'] = createNumericArrayAttr([1, 2, 2, 1]); + + const input1 = [tfc.scalar(1.0)]; + const input2 = [tfc.scalar(1.0)]; + node.inputNames = ['input1', 'input2']; + + executeOp(node, {input1, input2}, context); + + expect(tfc.dilation2d) + .toHaveBeenCalledWith( + input1[0], input2[0], [1, 1], 'same', [2, 2], 'NHWC'); + }); + }); }); diff --git a/tfjs-converter/src/operations/op_list/convolution.ts b/tfjs-converter/src/operations/op_list/convolution.ts index 9665afa7a30..d0665b3455a 100644 --- a/tfjs-converter/src/operations/op_list/convolution.ts +++ b/tfjs-converter/src/operations/op_list/convolution.ts @@ -1,8 +1,6 @@ -import {OpMapper} from '../types'; - /** * @license - * Copyright 2018 Google LLC. All Rights Reserved. + * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at @@ -17,6 +15,8 @@ import {OpMapper} from '../types'; * ============================================================================= */ +import {OpMapper} from '../types'; + export const json: OpMapper[] = [ { 'tfOpName': 'AvgPool', @@ -329,5 +329,18 @@ export const json: OpMapper[] = [ }, {'tfName': 'dilations', 'name': 'dilations', 'type': 'number[]'} ], + }, + { + 'tfOpName': 'Dilation2D', + 'category': 'convolution', + 'inputs': [ + {'start': 0, 'name': 'x', 'type': 'tensor'}, + {'start': 1, 'name': 'filter', 'type': 'tensor'}, + ], + 'attrs': [ + {'tfName': 'strides', 'name': 'strides', 'type': 'number[]'}, + {'tfName': 'rates', 'name': 'dilations', 'type': 'number[]'}, + {'tfName': 'padding', 'name': 'pad', 'type': 'string'} + ] } ]; diff --git a/tfjs-core/src/gradients/Dilation2D_grad.ts b/tfjs-core/src/gradients/Dilation2D_grad.ts new file mode 100644 index 00000000000..9d1c3fa40fc --- /dev/null +++ b/tfjs-core/src/gradients/Dilation2D_grad.ts @@ -0,0 +1,42 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ +import {ENGINE} from '../engine'; +import {Dilation2D, Dilation2DBackpropFilter, Dilation2DBackpropFilterInputs, Dilation2DBackpropInput, Dilation2DBackpropInputInputs} from '../kernel_names'; +import {GradConfig} from '../kernel_registry'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor, Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; + +export const dilation2dGradConfig: GradConfig = { + kernelName: Dilation2D, + inputsToSave: ['x', 'filter'], + gradFunc: (dy: Tensor4D, saved: Tensor[], attrs: NamedAttrMap) => { + const [x, filter] = saved as [Tensor4D, Tensor3D]; + + const inputInputs: Dilation2DBackpropInputInputs = {x, filter, dy}; + const filterInputs: Dilation2DBackpropFilterInputs = {x, filter, dy}; + + return { + x: () => ENGINE.runKernel( + Dilation2DBackpropInput, inputInputs as {} as NamedTensorMap, + attrs) as Tensor, + filter: () => ENGINE.runKernel( + Dilation2DBackpropFilter, + filterInputs as {} as NamedTensorMap, attrs) as Tensor + }; + } +}; diff --git a/tfjs-core/src/jasmine_util.ts b/tfjs-core/src/jasmine_util.ts index 2ed8a0f1c02..a8b8cadd8e9 100644 --- a/tfjs-core/src/jasmine_util.ts +++ b/tfjs-core/src/jasmine_util.ts @@ -94,9 +94,26 @@ export interface TestFilter { excludes?: string[]; } +/** + * Add test filtering logic to Jasmine's specFilter hook. + * + * @param testFilters Used to include a test suite, with the ability + * to selectively exclude some of the tests. + * Either `include` or `startsWith` must exist for a `TestFilter`. + * Tests that have the substrings specified by the include or startsWith + * will be included in the test run, unless one of the substrings specified + * by `excludes` appears in the name. + * @param customInclude Function to programmatically include a test. + * If this function returns true, a test will immediately run. 
Otherwise, + * `testFilters` is used for fine-grained filtering. + * + * If a test is not handled by `testFilters` or `customInclude`, the test will + * be excluded in the test run. + */ export function setupTestFilters( testFilters: TestFilter[], customInclude: (name: string) => boolean) { const env = jasmine.getEnv(); + // Account for --grep flag passed to karma by saving the existing specFilter. const grepFilter = env.specFilter; @@ -119,8 +136,7 @@ export function setupTestFilters( return true; } - // Include a describeWithFlags() test from tfjs-core only if the test is - // in the include list. + // Include tests of a test suite unless tests are in excludes list. for (let i = 0; i < testFilters.length; ++i) { const testFilter = testFilters[i]; if ((testFilter.include != null && @@ -137,6 +153,7 @@ export function setupTestFilters( return true; } } + // Otherwise ignore the test. return false; }; diff --git a/tfjs-core/src/kernel_names.ts b/tfjs-core/src/kernel_names.ts index b8c5f3bd845..348db2586aa 100644 --- a/tfjs-core/src/kernel_names.ts +++ b/tfjs-core/src/kernel_names.ts @@ -192,6 +192,22 @@ export type DepthwiseConv2dNativeBackpropInputInputs = export const Diag = 'Diag'; export type DiagInputs = Pick; +export const Dilation2D = 'Dilation2D'; +export type Dilation2DInputs = Pick; +export interface Dilation2DAttrs { + strides: [number, number]|number; + pad: 'valid'|'same'|number; + dilations: [number, number]|number; +} + +export const Dilation2DBackpropInput = 'Dilation2DBackpropInput'; +export type Dilation2DBackpropInputInputs = + Pick; + +export const Dilation2DBackpropFilter = 'Dilation2DBackpropFilter'; +export type Dilation2DBackpropFilterInputs = + Pick; + export const Div = 'Div'; export type DivInputs = BinaryInputs; diff --git a/tfjs-core/src/ops/conv_util.ts b/tfjs-core/src/ops/conv_util.ts index cffae54dfc4..8650c6dd5d0 100644 --- a/tfjs-core/src/ops/conv_util.ts +++ b/tfjs-core/src/ops/conv_util.ts @@ -1,6 +1,6 @@ /** * @license - * 
Copyright 2017 Google Inc. All Rights Reserved. + * Copyright 2020 Google Inc. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -73,6 +73,49 @@ export type Conv2DInfo = { filterShape: [number, number, number, number] }; +/** + * + * @param inputShape Input tensor shape is of the following dimensions: + * `[batch, height, width, inChannels]`. + * @param filterShape The filter shape is of the following dimensions: + * `[filterHeight, filterWidth, depth]`. + * @param strides The strides of the sliding window for each dimension of the + * input tensor: `[strideHeight, strideWidth]`. + * If `strides` is a single number, + * then `strideHeight == strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1*1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat The data format of the input and output data. + * Defaults to 'NHWC'. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`. + * Defaults to `[1, 1]`. If `dilations` is a single number, then + * `dilationHeight == dilationWidth`. + */ +export function computeDilation2DInfo( + inputShape: [number, number, number, number], + filterShape: [number, number, number], strides: number|[number, number], + pad: 'same'|'valid'|number, dataFormat: 'NHWC' = 'NHWC', + dilations: number|[number, number]) { + // `computerConv2DInfo` require filterShape to be in the dimension of: + // `[filterHeight, filterWidth, depth, outDepth]`, dilation2d doesn't have + // outDepth, it should have the same depth as the input. 
+ // Input shape: [batch, height, width, inChannels] + const inputChannels = inputShape[3]; + const $filterShape = + [...filterShape, inputChannels] as [number, number, number, number]; + const $dataFormat = convertConv2DDataFormat(dataFormat); + + return computeConv2DInfo( + inputShape, $filterShape, strides, dilations, pad, + null /* roundingMode */, null /* depthWise */, $dataFormat); +} + export function computePool2DInfo( inShape: [number, number, number, number], filterSize: [number, number]|number, strides: number|[number, number], diff --git a/tfjs-core/src/ops/dilation2d.ts b/tfjs-core/src/ops/dilation2d.ts new file mode 100644 index 00000000000..794962d6371 --- /dev/null +++ b/tfjs-core/src/ops/dilation2d.ts @@ -0,0 +1,103 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {ENGINE} from '../engine'; +import {Dilation2D, Dilation2DAttrs, Dilation2DInputs} from '../kernel_names'; +import {NamedAttrMap} from '../kernel_registry'; +import {Tensor3D, Tensor4D} from '../tensor'; +import {NamedTensorMap} from '../tensor_types'; +import {convertToTensor} from '../tensor_util_env'; +import {TensorLike} from '../types'; +import * as util from '../util'; + +import {reshape} from './array_ops'; +import {op} from './operation'; + +/** + * Computes the grayscale dilation over the input `x`. 
+ * + * @param x The input tensor, rank 3 or rank 4 of shape + * `[batch, height, width, inChannels]`. If rank 3, batch of 1 is assumed. + * @param filter The filter tensor, rank 3, of shape + * `[filterHeight, filterWidth, depth]`. + * @param strides The strides of the sliding window for each dimension of the + * input tensor: `[strideHeight, strideWidth]`. + * If `strides` is a single number, + * then `strideHeight == strideWidth`. + * @param pad The type of padding algorithm. + * - `same` and stride 1: output will be of same size as input, + * regardless of filter size. + * - `valid`: output will be smaller than input if filter is larger + * than 1*1x1. + * - For more info, see this guide: + * [https://www.tensorflow.org/api_guides/python/nn#Convolution]( + * https://www.tensorflow.org/api_guides/python/nn#Convolution) + * @param dataFormat Specify the data format of the input and output data. + * Defaults to 'NHWC'. Only 'NHWC' is currently supported. With the + * default format "NHWC", the data is stored in the order of: [batch, + * height, width, channels]. + * @param dilations The dilation rates: `[dilationHeight, dilationWidth]` + * in which we sample input values across the height and width dimensions + * for atrous morphological dilation. Defaults to `[1, 1]`. If `dilations` + * is a single number, then `dilationHeight == dilationWidth`. If it is + * greater than 1, then all values of `strides` must be 1. 
+ */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function dilation2d_( + x: T|TensorLike, filter: Tensor3D|TensorLike, + strides: [number, number]|number, pad: 'valid'|'same', + dilations: [number, number]|number = [1, 1], + dataFormat: 'NHWC' = 'NHWC'): T { + const $x = convertToTensor(x, 'x', 'dilation2d'); + const $filter = convertToTensor(filter, 'filter', 'dilation2d'); + + util.assert( + $x.rank === 3 || $x.rank === 4, + () => `Error in dilation2d: input must be rank 3 or 4, but got rank ` + + `${$x.rank}.`); + util.assert( + $filter.rank === 3, + () => `Error in dilation2d: filter must be rank 3, but got rank ` + + `${$filter.rank}.`); + util.assert( + dataFormat === 'NHWC', + () => `Error in dilation2d: Only NHWC is currently supported, ` + + `but got dataFormat of ${dataFormat}`); + + let x4D = $x as Tensor4D; + let reshapedTo4D = false; + + if ($x.rank === 3) { + x4D = reshape($x, [1, $x.shape[0], $x.shape[1], $x.shape[2]]); + reshapedTo4D = true; + } + + const inputs: Dilation2DInputs = {x: x4D, filter: $filter}; + const attrs: Dilation2DAttrs = {strides, pad, dilations}; + + const res = ENGINE.runKernel( + Dilation2D, inputs as {} as NamedTensorMap, + attrs as {} as NamedAttrMap) as T; + + if (reshapedTo4D) { + return reshape(res, [res.shape[1], res.shape[2], res.shape[3]]) as T; + } + + return res; +} + +export const dilation2d = op({dilation2d_}); diff --git a/tfjs-core/src/ops/dilation2d_test.ts b/tfjs-core/src/ops/dilation2d_test.ts new file mode 100644 index 00000000000..24acdea21c2 --- /dev/null +++ b/tfjs-core/src/ops/dilation2d_test.ts @@ -0,0 +1,280 @@ +/** + * @license + * Copyright 2020 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; +import {expectArraysClose} from '../test_util'; + +describeWithFlags('dilation2d', ALL_ENVS, () => { + it('valid padding.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'valid'); + + expect(result.shape).toEqual([1, 1, 1, 1]); + expectArraysClose(await result.data(), [.5]); + }); + + it('same padding.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([1, 2, 2, 1]); + expectArraysClose(await result.data(), [.5, .6, .7, .8]); + }); + + it('same padding depth 3.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 3]; + const filterShape: [number, number, number] = [2, 2, 3]; + const x = tf.tensor4d( + [.1, .2, .0, .2, .3, .1, .3, .4, .2, .4, .5, .3], inputShape); + const filter = tf.tensor3d( + [.4, .5, .3, .3, .4, .2, .1, .2, .0, .0, .1, -.1], filterShape); + + 
const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([1, 2, 2, 3]); + expectArraysClose( + await result.data(), [.5, .7, .3, .6, .8, .4, .7, .9, .5, .8, 1., .6]); + }); + + it('same padding batch 2.', async () => { + const inputShape: [number, number, number, number] = [2, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .2, .3, .4, .5], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([2, 2, 2, 1]); + expectArraysClose(await result.data(), [.5, .6, .7, .8, .6, .7, .8, .9]); + }); + + it('same padding filter 2.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same'); + + expect(result.shape).toEqual([1, 3, 3, 1]); + expectArraysClose( + await result.data(), [.7, .8, .7, 1, 1.1, 1, 1.1, 1.2, 1.3]); + }); + + it('valid padding non-square-window.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [1, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'valid'); + + expect(result.shape).toEqual([1, 2, 1, 1]); + expectArraysClose(await result.data(), [.5, .7]); + }); + + it('same padding dilations 2.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.4, 
.3, .1, .2], filterShape); + + const result = tf.dilation2d(x, filter, 1 /* strides */, 'same', 2); + + // Because dilations = 2, the effective filter is [3, 3, 1]: + // filter_eff = [[[.4], [.0], [.3]], + // [[.0], [.0], [.0]], + // [[.1], [.0], [.2]]] + expect(result.shape).toEqual([1, 3, 3, 1]); + expectArraysClose( + await result.data(), [.7, .8, .6, 1., 1.1, .9, .8, .9, .9]); + }); + + it('valid padding uneven stride.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 4, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d( + [.1, .2, .3, .4, .5, .6, .7, .8, .9, 1., 1.1, 1.2], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + + const result = tf.dilation2d(x, filter, [1, 2] /* strides */, 'valid'); + + expect(result.shape).toEqual([1, 2, 2, 1]); + expectArraysClose(await result.data(), [.8, 1., 1.2, 1.4]); + }); + + it('throws when input rank is not 3 or 4.', async () => { + const filterShape: [number, number, number] = [1, 1, 1]; + // tslint:disable-next-line:no-any + const x: any = tf.tensor1d([.5]); + const filter = tf.tensor3d([.4], filterShape); + + expect(() => tf.dilation2d(x, filter, 1, 'valid')).toThrowError(); + }); + + it('throws when filter is not rank 3.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number] = [2, 2]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + // tslint:disable-next-line:no-any + const filter: any = tf.tensor2d([.4, .3, .1, .0], filterShape); + + expect(() => tf.dilation2d(x, filter, 1, 'valid')).toThrowError(); + }); + + it('throws when data format is not NHWC.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .0], filterShape); + // tslint:disable-next-line:no-any + const 
dataFormat: any = 'NCHW'; + + expect( + () => tf.dilation2d(x, filter, 1 /* strides */, 'valid', 1, dataFormat)) + .toThrowError(); + }); + + it('gradient valid padding.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [1, 1, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.5], filterShape); + const dy = tf.tensor4d([.2, .3, .4, .2, .1, 1., .2, .3, .4], inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'valid')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [3.1]); + }); + + it('gradient same padding.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [1, 1, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.5], filterShape); + const dy = tf.tensor4d([.2, .3, .4, .2, .1, 1., .2, .3, .4], inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'same')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [3.1]); + }); + + it('gradient same padding depth 2.', async () => { + const inputShape: [number, number, number, number] = [1, 2, 2, 3]; + const filterShape: [number, number, number] = [1, 1, 3]; + const x = tf.tensor4d( + [.1, .2, .0, .2, .3, .1, .3, .4, .2, .4, .5, .3], inputShape); + const filter = tf.tensor3d([.4, .5, .6], filterShape); + const dy = tf.tensor4d( + 
[.2, .3, .4, .2, .1, 1., .2, .3, .4, .8, -.1, .1], inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'same')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose( + await dx.data(), [.2, .3, .4, .2, .1, 1., .2, .3, .4, .8, -.1, .1]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [1.4, .6, 1.9]); + }); + + it('gradient valid padding filter 2.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 1]; + const filterShape: [number, number, number] = [2, 2, 1]; + const dyShape: [number, number, number, number] = [1, 2, 2, 1]; + const x = tf.tensor4d([.1, .2, .3, .4, .5, .6, .7, .8, .9], inputShape); + const filter = tf.tensor3d([.4, .3, .1, .2], filterShape); + const dy = tf.tensor4d([.2, .3, .4, .2], dyShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 'valid')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [0, 0, 0, 0, .2, .3, 0, .4, .2]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose(await dfilter.data(), [0, 0, 0, 1.1]); + }); + + it('gradient same padding filter 2 depth 3.', async () => { + const inputShape: [number, number, number, number] = [1, 3, 3, 3]; + const filterShape: [number, number, number] = [2, 2, 3]; + const x = tf.tensor4d( + [ + .1, .2, .3, .4, .5, .6, .7, .8, .9, .3, .2, .3, .4, .5, + .1, .9, .6, .3, .4, .5, .6, .2, .3, .5, .1, .2, .3 + ], + inputShape); + const filter = tf.tensor3d( + [.4, .3, .1, .2, .2, .1, .7, .3, .8, .4, .9, .1], filterShape); + const dy = tf.tensor4d( + [ + .2, .3, .4, .2, .1, .5, 0, .8, .7, .1, .2, .1, .2, .3, + .4, .5, .6, .6, .6, .7, .8, .3, .2, .1, .2, .4, .2 + ], + inputShape); + + const grads = tf.grads( + (x: tf.Tensor4D, filter: tf.Tensor3D) => + x.dilation2d(filter, 1, 
'same')); + + const [dx, dfilter] = grads([x, filter], dy); + + expect(dx.shape).toEqual(x.shape); + expectArraysClose(await dx.data(), [ + 0, 0, 0, 0, 0, 0, 0, .8, .5, .2, 0, .4, 0, .3, + 0, .9, .7, .7, .7, .7, .9, .3, .4, .5, .2, .7, .8 + ]); + + expect(dfilter.shape).toEqual(filterShape); + expectArraysClose( + await dfilter.data(), + [1.6, 2.7, 1.1, .2, 0, .5, .3, 0, 2.2, .2, .9, 0]); + }); +}); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index ab66c92d783..490d1d9ae94 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -43,6 +43,7 @@ export {cumsum} from './cumsum'; export {depthToSpace} from './depth_to_space'; export {depthwiseConv2d} from './depthwise_conv2d'; export {diag} from './diag'; +export {dilation2d} from './dilation2d'; export {div} from './div'; export {divNoNan} from './div_no_nan'; export {dot} from './dot'; diff --git a/tfjs-core/src/public/chained_ops/dilation2d.ts b/tfjs-core/src/public/chained_ops/dilation2d.ts new file mode 100644 index 00000000000..39b2ec5d62e --- /dev/null +++ b/tfjs-core/src/public/chained_ops/dilation2d.ts @@ -0,0 +1,36 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ +import {dilation2d} from '../../ops/dilation2d'; +import {Tensor, Tensor3D, Tensor4D} from '../../tensor'; +import {Rank, TensorLike3D} from '../../types'; + +declare module '../../tensor' { + interface Tensor<R extends Rank = Rank> { + dilation2d<T extends Tensor3D|Tensor4D>( + filter: Tensor3D|TensorLike3D, strides: [number, number]|number, + pad: 'valid'|'same', dilations?: [number, number]|number, + dataFormat?: 'NHWC'): T; + } +} + +Tensor.prototype.dilation2d = function<T extends Tensor3D|Tensor4D>( + filter: Tensor3D|TensorLike3D, strides: [number, number]|number, + pad: 'valid'|'same', dilations?: [number, number]|number, + dataFormat?: 'NHWC'): T { + this.throwIfDisposed(); + return dilation2d(this, filter, strides, pad, dilations, dataFormat) as T; +}; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts index 9cf695f7dbd..5c71468725f 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops.ts @@ -28,6 +28,7 @@ import './cumsum'; import './depth_to_space'; import './depthwise_conv2d'; import './depthwise_conv2D_deprecated'; +import './dilation2d'; import './div'; import './div_no_nan'; import './dot'; diff --git a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts index 7a6cf931d77..a35227e811c 100644 --- a/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts +++ b/tfjs-core/src/public/chained_ops/register_all_chained_ops_test.ts @@ -38,6 +38,7 @@ const CHAINED_OPS = [ 'depthToSpace', 'depthwiseConv2d', 'depthwiseConv2D', + 'dilation2d', 'div', 'divNoNan', 'dot', diff --git a/tfjs-core/src/register_all_gradients.ts b/tfjs-core/src/register_all_gradients.ts index f7e573682f9..563fa8d81d7 100644 --- a/tfjs-core/src/register_all_gradients.ts +++ b/tfjs-core/src/register_all_gradients.ts @@ -28,6 
+28,7 @@ import {conv2DBackpropInputGradConfig} from './gradients/Conv2DBackpropInput_gra import {conv3DGradConfig} from './gradients/Conv3D_grad'; import {cumsumGradConfig} from './gradients/Cumsum_grad'; import {depthwiseConv2dNativeGradConfig} from './gradients/DepthwiseConv2dNative_grad'; +import {dilation2dGradConfig} from './gradients/Dilation2D_grad'; import {divGradConfig} from './gradients/Div_grad'; import {eluGradConfig} from './gradients/Elu_grad'; import {floorDivGradConfig} from './gradients/FloorDiv_grad'; @@ -78,6 +79,7 @@ const gradConfigs: GradConfig[] = [ conv3DGradConfig, cumsumGradConfig, depthwiseConv2dNativeGradConfig, + dilation2dGradConfig, divGradConfig, eluGradConfig, floorDivGradConfig, diff --git a/tfjs-core/src/tests.ts b/tfjs-core/src/tests.ts index f144ae92357..6a6492563d7 100644 --- a/tfjs-core/src/tests.ts +++ b/tfjs-core/src/tests.ts @@ -69,6 +69,7 @@ import './ops/conv_util_test'; import './ops/cumsum_test'; import './ops/depth_to_space_test'; import './ops/diag_test'; +import './ops/dilation2d_test'; import './ops/dropout_test'; import './ops/dropout_util_test'; import './ops/elu_test'; diff --git a/tfjs-core/src/util.ts b/tfjs-core/src/util.ts index 438445dd452..929207bf752 100644 --- a/tfjs-core/src/util.ts +++ b/tfjs-core/src/util.ts @@ -609,7 +609,7 @@ export function toNestedArray(shape: number[], a: TypedArray) { return []; } if (size !== a.length) { - throw new Error(`[${shape}] does not match the input size.`); + throw new Error(`[${shape}] does not match the input size ${a.length}.`); } return createNestedArray(0, shape, a); @@ -643,6 +643,25 @@ export function makeZerosTypedArray( } } +/** + * Make nested `TypedArray` filled with zeros. + * @param shape The shape information for the nested array. + * @param dtype dtype of the array element. 
+ */ +export function makeZerosNestedTypedArray<D extends DataType>( + shape: number[], dtype: D) { + const size = shape.reduce((prev, curr) => prev * curr, 1); + if (dtype == null || dtype === 'float32') { + return toNestedArray(shape, new Float32Array(size)); + } else if (dtype === 'int32') { + return toNestedArray(shape, new Int32Array(size)); + } else if (dtype === 'bool') { + return toNestedArray(shape, new Uint8Array(size)); + } else { + throw new Error(`Unknown data type ${dtype}`); + } +} + /** * Returns the current high-resolution time in milliseconds relative to an * arbitrary time in the past. It works across different platforms (node.js, diff --git a/tfjs-node/src/index.ts b/tfjs-node/src/index.ts index 97e1ddc0111..d098afcdb2a 100644 --- a/tfjs-node/src/index.ts +++ b/tfjs-node/src/index.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -// Import all kernels. +// Register all kernels. import './register_all_kernels'; import * as tf from '@tensorflow/tfjs'; @@ -66,6 +66,7 @@ export * from './node'; // tslint:disable-next-line:no-require-imports const pjson = require('../package.json'); +// Side effects for default initialization of Node backend. tf.registerBackend('tensorflow', () => { return new NodeJSKernelBackend(bindings as TFJSBinding, pjson.name); }, 3 /* priority */); diff --git a/tfjs-node/src/kernels/Dilation2D.ts b/tfjs-node/src/kernels/Dilation2D.ts new file mode 100644 index 00000000000..609fb7b9ceb --- /dev/null +++ b/tfjs-node/src/kernels/Dilation2D.ts @@ -0,0 +1,51 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2D, Dilation2DAttrs, Dilation2DInputs, KernelConfig} from '@tensorflow/tfjs'; + +import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +export const dilation2dConfig: KernelConfig = { + kernelName: Dilation2D, + backendName: 'tensorflow', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter} = inputs as Dilation2DInputs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; + const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); + const $strides = [1, strideHeight, strideWidth, 1]; + const $dilations = [1, dilationHeight, dilationWidth, 1]; + + const nodeBackend = backend as NodeJSKernelBackend; + + const opAttrs = [ + createTensorsTypeOpAttr('T', x.dtype), + {name: 'strides', type: nodeBackend.binding.TF_ATTR_INT, value: $strides}, + {name: 'rates', type: nodeBackend.binding.TF_ATTR_INT, value: $dilations}, + { + name: 'padding', + type: nodeBackend.binding.TF_ATTR_STRING, + value: padInfo.type + } + ]; + + return nodeBackend.executeSingleOutput(Dilation2D, opAttrs, [x, filter]); + } +}; diff --git a/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts new file mode 100644 index 00000000000..7896e03cc90 --- /dev/null +++ 
b/tfjs-node/src/kernels/Dilation2DBackpropFilter.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropFilter, Dilation2DBackpropFilterInputs, KernelConfig} from '@tensorflow/tfjs'; + +import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +export const dilation2dBackpropFilterConfig: KernelConfig = { + kernelName: Dilation2DBackpropFilter, + backendName: 'tensorflow', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = inputs as Dilation2DBackpropFilterInputs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; + const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); + const $strides = [1, strideHeight, strideWidth, 1]; + const $dilations = [1, dilationHeight, dilationWidth, 1]; + + const nodeBackend = backend as NodeJSKernelBackend; + + const opAttrs = [ + createTensorsTypeOpAttr('T', x.dtype), + {name: 'strides', type: nodeBackend.binding.TF_ATTR_INT, value: $strides}, + {name: 'rates', type: nodeBackend.binding.TF_ATTR_INT, value: $dilations}, + { + name: 
'padding', + type: nodeBackend.binding.TF_ATTR_STRING, + value: padInfo.type + } + ]; + + return nodeBackend.executeSingleOutput( + Dilation2DBackpropFilter, opAttrs, [x, filter, dy]); + } +}; diff --git a/tfjs-node/src/kernels/Dilation2DBackpropInput.ts b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts new file mode 100644 index 00000000000..61c931dd6d5 --- /dev/null +++ b/tfjs-node/src/kernels/Dilation2DBackpropInput.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {backend_util, Dilation2DAttrs, Dilation2DBackpropInput, Dilation2DBackpropInputInputs, KernelConfig} from '@tensorflow/tfjs'; + +import {createTensorsTypeOpAttr, NodeJSKernelBackend} from '../nodejs_kernel_backend'; + +export const dilation2dBackpropInputConfig: KernelConfig = { + kernelName: Dilation2DBackpropInput, + backendName: 'tensorflow', + kernelFunc: ({inputs, backend, attrs}) => { + const {x, filter, dy} = inputs as Dilation2DBackpropInputInputs; + const {strides, pad, dilations} = attrs as {} as Dilation2DAttrs; + const {dilationHeight, dilationWidth, padInfo, strideHeight, strideWidth} = + backend_util.computeDilation2DInfo( + x.shape as [number, number, number, number], + filter.shape as [number, number, number], strides, pad, + 'NHWC' /* dataFormat */, dilations); + const $strides = [1, strideHeight, strideWidth, 1]; + const $dilations = [1, dilationHeight, dilationWidth, 1]; + + const nodeBackend = backend as NodeJSKernelBackend; + + const opAttrs = [ + createTensorsTypeOpAttr('T', x.dtype), + {name: 'strides', type: nodeBackend.binding.TF_ATTR_INT, value: $strides}, + {name: 'rates', type: nodeBackend.binding.TF_ATTR_INT, value: $dilations}, + { + name: 'padding', + type: nodeBackend.binding.TF_ATTR_STRING, + value: padInfo.type + } + ]; + + return nodeBackend.executeSingleOutput( + Dilation2DBackpropInput, opAttrs, [x, filter, dy]); + } +}; diff --git a/tfjs-node/src/register_all_kernels.ts b/tfjs-node/src/register_all_kernels.ts index 56ea721afa0..a9d5d0fc16b 100644 --- a/tfjs-node/src/register_all_kernels.ts +++ b/tfjs-node/src/register_all_kernels.ts @@ -19,13 +19,19 @@ // the contents of this file and import only the kernels that are needed. 
import {KernelConfig, registerKernel} from '@tensorflow/tfjs-core'; +import {dilation2dConfig} from './kernels/Dilation2D'; +import {dilation2dBackpropFilterConfig} from './kernels/Dilation2DBackpropFilter'; +import {dilation2dBackpropInputConfig} from './kernels/Dilation2DBackpropInput'; import {nonMaxSuppressionV5Config} from './kernels/NonMaxSuppressionV5'; import {softmaxConfig} from './kernels/Softmax'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; // List all kernel configs here -const kernelConfigs: KernelConfig[] = - [nonMaxSuppressionV5Config, softmaxConfig, squaredDifferenceConfig]; +const kernelConfigs: KernelConfig[] = [ + dilation2dConfig, dilation2dBackpropInputConfig, + dilation2dBackpropFilterConfig, nonMaxSuppressionV5Config, softmaxConfig, + squaredDifferenceConfig +]; for (const kernelConfig of kernelConfigs) { registerKernel(kernelConfig);