diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 56d4e319d1b..2fa8a501094 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -261,6 +261,7 @@ tfjs_cc_library( ":Softmax", ":Square", ":SquaredDifference", + ":StridedSlice", ":Sub", ":Tile", ":Transpose", @@ -917,6 +918,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "StridedSlice", + srcs = ["kernels/StridedSlice.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Sub", srcs = ["kernels/Sub.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc new file mode 100644 index 00000000000..0e1dc252a0a --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc @@ -0,0 +1,67 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +extern "C" { +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void StridedSlice(const size_t x_id, const int32_t* x_strides_ptr, + const size_t x_rank, const int32_t* begin_ptr, + const int32_t* end_ptr, const int32_t* strides_ptr, + const int32_t* out_shape_ptr, const int32_t* out_strides_ptr, + const size_t out_shape_size, const size_t out_id) { + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info_out(out_id); + const size_t out_size = out_info.size; + + const float* x_ptr = x_info.f32(); + float* out_buf_ptr = out_info.f32_write(); + + const auto x_strides = + std::vector(x_strides_ptr, x_strides_ptr + x_rank - 1); + const auto begin = std::vector(begin_ptr, begin_ptr + x_rank); + const auto end = std::vector(end_ptr, end_ptr + x_rank); + const auto strides = std::vector(strides_ptr, strides_ptr + x_rank); + const auto out_shape = + std::vector(out_shape_ptr, out_shape_ptr + out_shape_size); + const auto out_strides = std::vector( + out_strides_ptr, out_strides_ptr + out_shape_size - 1); + + for (size_t i = 0; i < out_size; ++i) { + auto coords = tfjs::util::offset_to_loc(i, out_strides); + + std::vector new_loc = {}; + for (size_t j = 0; j < out_shape_size; ++j) { + new_loc.push_back(coords[j] * strides[j] + begin[j]); + } + + const size_t x_index = tfjs::util::loc_to_offset(new_loc, x_strides); + *out_buf_ptr = x_ptr[x_index]; + + out_buf_ptr++; + } +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts new file mode 100644 index 00000000000..bc55b8fa827 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -0,0 +1,146 @@ +/** + * @license + * Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {backend_util, KernelConfig, KernelFunc, StridedSlice, StridedSliceAttrs, StridedSliceInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; +import {reshape} from './Reshape'; +import {slice} from './Slice'; + +let wasmStridedSlice: ( + xId: number, xStridesBytes: Uint8Array, xRank: number, + beginBytes: Uint8Array, endBytes: Uint8Array, stridesBytes: Uint8Array, + outShapeBytes: Uint8Array, outStridesBytes: Uint8Array, + outShapeLength: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmStridedSlice = backend.wasm.cwrap(StridedSlice, null /*void*/, [ + 'number', // xId + 'array', // xStrides + 'number', // xRank + 'array', // beginBytes + 'array', // endBytes + 'array', // stridesBytes + 'array', // outShapeBytes + 'array', // outStridesBytes + 'number', // outShapeLength + 'number', // outId + ]); +} + +export function stridedSlice(args: { + backend: BackendWasm, + inputs: StridedSliceInputs, + attrs: StridedSliceAttrs +}): TensorInfo { + const {backend, inputs, attrs} = args; + const {x} = inputs; + + let {begin, end, strides} = attrs; + if (strides == null) { + strides = new Array(begin.length); + } + + const {beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask} = attrs; + + const ellipsisAxes = backend_util.slice_util.maskToAxes(ellipsisMask); + if (ellipsisAxes.length > 1) { + throw new Error('Multiple ellipses in slice is not allowed.'); + } + + if (ellipsisMask !== 0 && newAxisMask !== 0) { + throw new Error( + 'Using both ellipsisMask and newAxisMask is not yet supported.'); + } + + if (ellipsisMask !== 0 && shrinkAxisMask !== 0) { + throw new Error( + 'Using both ellipsisMask and shrinkAxisMask is not yet supported.'); + } + + const numInterpolatedAxes = x.shape.length - begin.length; + + // Expand the dims of x based on the newAxisMask. + const expandAxes = backend_util.slice_util.maskToAxes(newAxisMask); + const newShape = x.shape.slice(); + expandAxes.forEach(axis => { + begin[axis] = 0; + end[axis] = 1; + newShape.splice(axis, 0, 1); + }); + + const xReshaped = reshape({inputs: {x}, attrs: {shape: newShape}, backend}); + + const { + begin: normalizedBegin, + end: normalizedEnd, + strides: normalizedStrides + } = + backend_util.slice_util.getNormalizedAxes( + xReshaped.shape, ellipsisAxes, numInterpolatedAxes, begin, end, + strides, beginMask, endMask, ellipsisMask); + begin = normalizedBegin; + end = normalizedEnd; + strides = normalizedStrides; + + const shrinkAxes = backend_util.slice_util.maskToAxes(shrinkAxisMask); + // Adjust the ends based on the shrink mask. + shrinkAxes.forEach(axis => { + end[axis] = begin[axis] + 1; + strides[axis] = 1; + }); + + // Figure out the output shape. + const size = backend_util.slice_util.computeOutShape(begin, end, strides); + // Remove the axes based on shrinkMask. + const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1); + + const nonStrided = strides.every(v => v === 1); + if (nonStrided) { + const xSliced = slice({inputs: {x}, attrs: {begin, size}, backend}); + return reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend}); + } + + const out = backend.makeOutput(outShape, 'float32'); + if (!outShape.some(axis => axis === 0)) { + const xId = backend.dataIdMap.get(xReshaped.dataId).id; + const xStridesBytes = new Uint8Array( + new Int32Array(util.computeStrides(xReshaped.shape)).buffer); + const beginBytes = new Uint8Array(new Int32Array(begin).buffer); + const endBytes = new Uint8Array(new Int32Array(end).buffer); + const stridesBytes = new Uint8Array(new Int32Array(strides).buffer); + + const outputShapeBytes = new Uint8Array(new Int32Array(outShape).buffer); + const outStridesBytes = + new Uint8Array(new Int32Array(util.computeStrides(outShape)).buffer); + const outId = backend.dataIdMap.get(out.dataId).id; + + wasmStridedSlice( + xId, xStridesBytes, xReshaped.shape.length, beginBytes, endBytes, + stridesBytes, outputShapeBytes, outStridesBytes, outShape.length, + outId); + } + + return reshape({inputs: {x: out}, attrs: {shape: outShape}, backend}); +} + +export const stridedSliceConfig: KernelConfig = { + kernelName: StridedSlice, + backendName: 'wasm', + setupFunc: setup, + kernelFunc: stridedSlice as {} as KernelFunc +}; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 671abc9d53b..7e4d70f6928 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -87,6 +87,7 @@ import {splitVConfig} from './kernels/Split'; import {sqrtConfig} from './kernels/Sqrt'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {stridedSliceConfig} from './kernels/StridedSlice'; import {subConfig} from './kernels/Sub'; import {sumConfig} from './kernels/Sum'; import {tanhConfig} from './kernels/Tanh'; @@ -165,6 +166,7 @@ const kernelConfigs: KernelConfig[] = [ sqrtConfig, squareConfig, squaredDifferenceConfig, + stridedSliceConfig, subConfig, sumConfig, tanhConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 247a9a7ec37..496973366cd 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -192,6 +192,7 @@ const TEST_FILTERS: TestFilter[] = [ ] }, {include: 'slice '}, + {include: 'stridedSlice '}, {include: 'rotate '}, {include: 'flipLeftRight '}, {include: 'square '}, diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts index a5d8899353f..c18d8cac38f 100644 --- a/tfjs-core/src/backends/backend_util.ts +++ b/tfjs-core/src/backends/backend_util.ts @@ -36,6 +36,9 @@ export * from '../ops/fused_util'; export * from '../ops/fused_types'; export * from '../ops/reduce_util'; +import * as slice_util from '../ops/slice_util'; +export {slice_util}; + export {BackendValues, TypedArray, upcastType, PixelData} from '../types'; export {MemoryInfo, TimingInfo} from '../engine'; export * from '../ops/rotate_util'; diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 10799d2a57b..bd236038c96 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -102,6 +102,45 @@ function getElidedAxes(numElidedAxes: number, ellipsisInsertionIndex: number) { return elidedAxes; } +// Normalize the start, end and strides. +export function getNormalizedAxes( + inputShape: number[], ellipsisAxes: number[], numInterpolatedAxes: number, + begin: number[], end: number[], strides: number[], beginMask: number, + endMask: number, + ellipsisMask: number): {begin: number[], end: number[], strides: number[]} { + const inputRank = inputShape.length; + let normalizedBegin = new Array(inputRank), + normalizedEnd = new Array(inputRank), + normalizedStrides = new Array(inputRank); + if (ellipsisAxes.length && numInterpolatedAxes > 0) { + const fullIndex = ellipsisAxes[0]; + + // The ellipsis applies to the masked index as well as any dimensions + // that are interpolated. + const numElidedAxes = numInterpolatedAxes + 1; + normalizedBegin = startIndicesWithElidedDims( + beginMask, fullIndex, numElidedAxes, begin, inputShape); + normalizedEnd = stopIndicesWithElidedDims( + endMask, fullIndex, numElidedAxes, end, inputShape); + normalizedStrides = + stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape); + } else { + for (let axis = 0; axis < inputRank; axis++) { + normalizedBegin[axis] = startForAxis( + beginMask, begin, strides, inputShape, axis, ellipsisMask); + normalizedEnd[axis] = + stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask); + normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask); + } + } + + return { + begin: normalizedBegin, + end: normalizedEnd, + strides: normalizedStrides + }; +} + // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current start value. Otherwise, insert. export function startIndicesWithElidedDims( diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index 90e20171ba7..d99020c16ff 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -26,7 +26,7 @@ import {TensorLike} from '../types'; import {op} from './operation'; import {reshape} from './reshape'; import {slice} from './slice'; -import {computeOutShape, maskToAxes, startForAxis, startIndicesWithElidedDims, stopForAxis, stopIndicesWithElidedDims, stridesForAxis, stridesWithElidedDims} from './slice_util'; +import {computeOutShape, getNormalizedAxes, maskToAxes} from './slice_util'; /** * Extracts a strided slice of a tensor. @@ -65,13 +65,13 @@ function stridedSlice_( x: Tensor|TensorLike, begin: number[], end: number[], strides?: number[], beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0, shrinkAxisMask = 0): Tensor { - if (strides == null) { - strides = new Array(begin.length); - } - let $x = convertToTensor(x, 'x', 'stridedSlice'); const forward: ForwardFunc = (backend) => { + if (strides == null) { + strides = new Array(begin.length); + } + const ellipsisAxes = maskToAxes(ellipsisMask); if (ellipsisAxes.length > 1) { throw new Error('Multiple ellipses in slice is not allowed.'); @@ -99,28 +99,17 @@ function stridedSlice_( }); $x = reshape($x, newShape); - // Normalize the start, end and strides. - if (ellipsisAxes.length && numInterpolatedAxes > 0) { - const fullIndex = ellipsisAxes[0]; - - // The ellipsis applies to the masked index as well as any dimensions - // that are interpolated. - const numElidedAxes = numInterpolatedAxes + 1; - begin = startIndicesWithElidedDims( - beginMask, fullIndex, numElidedAxes, begin, $x.shape); - end = stopIndicesWithElidedDims( - endMask, fullIndex, numElidedAxes, end, $x.shape); - strides = - stridesWithElidedDims(strides, fullIndex, numElidedAxes, $x.shape); - } else { - for (let axis = 0; axis < $x.rank; axis++) { - begin[axis] = startForAxis( - beginMask, begin, strides, $x.shape, axis, ellipsisMask); - end[axis] = - stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask); - strides[axis] = stridesForAxis(strides, axis, ellipsisMask); - } - } + const { + begin: normalizedBegin, + end: normalizedEnd, + strides: normalizedStrides + } = + getNormalizedAxes( + $x.shape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, + beginMask, endMask, ellipsisMask); + begin = normalizedBegin; + end = normalizedEnd; + strides = normalizedStrides; const shrinkAxes = maskToAxes(shrinkAxisMask); // Adjust the ends based on the shrink mask.