diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 48463084eb4..56d4e319d1b 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -219,6 +219,7 @@ tfjs_cc_library( ":Conv2DBackpropInput", ":CropAndResize", ":Cumsum", + ":DepthToSpace", ":DepthwiseConv2dNative", ":Div", ":Equal", @@ -422,6 +423,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "DepthToSpace", + srcs = ["kernels/DepthToSpace.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "DepthwiseConv2dNative", srcs = ["kernels/DepthwiseConv2dNative.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/DepthToSpace.cc b/tfjs-backend-wasm/src/cc/kernels/DepthToSpace.cc new file mode 100644 index 00000000000..884092a4ee6 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/DepthToSpace.cc @@ -0,0 +1,75 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +extern "C" { +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void DepthToSpace(const size_t x_id, const size_t block_size, + const bool channels_last, const int32_t* x_strides_ptr, + const size_t x_strides_size, const int32_t* out_shape_ptr, + const int32_t* out_strides_ptr, const size_t out_shape_size, + const size_t out_id) { + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info_out(out_id); + const size_t out_size = out_info.size; + + const float* x_ptr = x_info.f32(); + float* out_buf_ptr = out_info.f32_write(); + + const auto x_strides = + std::vector(x_strides_ptr, x_strides_ptr + x_strides_size); + const auto out_shape = + std::vector(out_shape_ptr, out_shape_ptr + out_shape_size); + const auto out_strides = std::vector( + out_strides_ptr, out_strides_ptr + out_shape_size - 1); + + for (size_t i = 0; i < out_size; ++i) { + auto coords = tfjs::util::offset_to_loc(i, out_strides); + const size_t b = coords[0]; + const size_t h = channels_last ? coords[1] : coords[2]; + const size_t w = channels_last ? coords[2] : coords[3]; + const size_t d = channels_last ? coords[3] : coords[1]; + const size_t out_depth_size = channels_last ? out_shape[3] : out_shape[1]; + + const size_t in_h = h / block_size; + const size_t offset_h = h % block_size; + const size_t in_w = w / block_size; + const size_t offset_w = w % block_size; + const size_t offset_d = (offset_h * block_size + offset_w) * out_depth_size; + + const size_t in_d = d + offset_d; + + size_t x_index = + channels_last + ? tfjs::util::loc_to_offset({b, in_h, in_w, in_d}, x_strides) + : tfjs::util::loc_to_offset({b, in_d, in_h, in_w}, x_strides); + *out_buf_ptr = x_ptr[x_index]; + + out_buf_ptr++; + } +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/DepthToSpace.ts b/tfjs-backend-wasm/src/kernels/DepthToSpace.ts new file mode 100644 index 00000000000..bbcd24cd7b8 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/DepthToSpace.ts @@ -0,0 +1,92 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {DepthToSpace, DepthToSpaceAttrs, DepthToSpaceInputs, KernelConfig, KernelFunc, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +let wasmDepthToSpace: ( + xId: number, blockSize: number, channelsLast: number, xStrides: Uint8Array, + xStridesLength: number, outputShape: Uint8Array, outputStrides: Uint8Array, + outSize: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmDepthToSpace = backend.wasm.cwrap(DepthToSpace, null /*void*/, [ + 'number', // xId + 'number', // blockSize + 'number', // channelsLast + 'array', // xStrides + 'number', // xStridesLength + 'array', // outputShape + 'array', // outputStrides + 'number', // outSize + 'number', // outId + ]); +} + +export function depthToSpace(args: { + backend: BackendWasm, + inputs: DepthToSpaceInputs, + attrs: DepthToSpaceAttrs +}): TensorInfo { + const {backend, inputs, attrs} = args; + const {x} = inputs; + const {blockSize, dataFormat} = attrs; + + util.assert( + blockSize > 1, + () => `blockSize should be > 1 for depthToSpace, but was: ${blockSize}`); + + const batchSize = x.shape[0]; + const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; + const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; + const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; + + const outputHeight = inputHeight * blockSize; + const outputWidth = inputWidth * blockSize; + const outputDepth = inputDepth / (blockSize * blockSize); + + const outputShape = (dataFormat === 'NHWC') ? + [batchSize, outputHeight, outputWidth, outputDepth] : + [batchSize, outputDepth, outputHeight, outputWidth]; + + const out = backend.makeOutput(outputShape, 'float32'); + + const xData = backend.dataIdMap.get(x.dataId); + const xId = xData.id; + const xStridesBytes = + new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer); + + const outputShapeBytes = new Uint8Array(new Int32Array(outputShape).buffer); + const outStridesBytes = + new Uint8Array(new Int32Array(util.computeStrides(outputShape)).buffer); + + const outId = backend.dataIdMap.get(out.dataId).id; + const channelsLast = dataFormat === 'NHWC' ? 1 : 0; + wasmDepthToSpace( + xId, blockSize, channelsLast, xStridesBytes, x.shape.length - 1, + outputShapeBytes, outStridesBytes, outputShape.length, outId); + + return out; +} + +export const depthToSpaceConfig: KernelConfig = { + kernelName: DepthToSpace, + backendName: 'wasm', + setupFunc: setup, + kernelFunc: depthToSpace as {} as KernelFunc +}; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 6da4734e223..671abc9d53b 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -34,6 +34,7 @@ import {conv2DBackpropInputConfig} from './kernels/Conv2DBackpropInput'; import {cosConfig} from './kernels/Cos'; import {cropAndResizeConfig} from './kernels/CropAndResize'; import {cumsumConfig} from './kernels/Cumsum'; +import {depthToSpaceConfig} from './kernels/DepthToSpace'; import {depthwiseConv2dNativeConfig} from './kernels/DepthwiseConv2dNative'; import {divConfig} from './kernels/Div'; import {equalConfig} from './kernels/Equal'; @@ -110,6 +111,7 @@ const kernelConfigs: KernelConfig[] = [ cosConfig, cropAndResizeConfig, cumsumConfig, + depthToSpaceConfig, depthwiseConv2dNativeConfig, divConfig, equalConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 17584ed4bc9..247a9a7ec37 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -52,6 +52,7 @@ const TEST_FILTERS: TestFilter[] = [ 'complex', // Complex numbers not supported yet ] }, + {include: 'depthToSpace'}, { include: 'avgPool', excludes: [