From 8ba8e5e6575418bca4ba7d81bbdeffaf90728ca0 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 9 Sep 2020 12:53:46 -0400 Subject: [PATCH 01/11] setup --- tfjs-backend-wasm/src/cc/BUILD | 10 ++ .../src/cc/kernels/StridedSlice.cc | 75 +++++++++++++ tfjs-backend-wasm/src/kernels/StridedSlice.ts | 101 ++++++++++++++++++ tfjs-backend-wasm/src/register_all_kernels.ts | 2 + tfjs-backend-wasm/src/setup_test.ts | 1 + 5 files changed, 189 insertions(+) create mode 100644 tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc create mode 100644 tfjs-backend-wasm/src/kernels/StridedSlice.ts diff --git a/tfjs-backend-wasm/src/cc/BUILD b/tfjs-backend-wasm/src/cc/BUILD index 56d4e319d1b..2fa8a501094 100644 --- a/tfjs-backend-wasm/src/cc/BUILD +++ b/tfjs-backend-wasm/src/cc/BUILD @@ -261,6 +261,7 @@ tfjs_cc_library( ":Softmax", ":Square", ":SquaredDifference", + ":StridedSlice", ":Sub", ":Tile", ":Transpose", @@ -917,6 +918,15 @@ tfjs_cc_library( ], ) +tfjs_cc_library( + name = "StridedSlice", + srcs = ["kernels/StridedSlice.cc"], + deps = [ + ":backend", + ":util", + ], +) + tfjs_cc_library( name = "Sub", srcs = ["kernels/Sub.cc"], diff --git a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc new file mode 100644 index 00000000000..272a7344f30 --- /dev/null +++ b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc @@ -0,0 +1,75 @@ +/* Copyright 2020 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ===========================================================================*/ + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include "src/cc/backend.h" +#include "src/cc/util.h" + +namespace tfjs { +namespace wasm { +extern "C" { +#ifdef __EMSCRIPTEN__ +EMSCRIPTEN_KEEPALIVE +#endif + +void StridedSlice(const size_t x_id, const size_t block_size, + const bool channels_last, const int32_t* x_strides_ptr, + const size_t x_strides_size, const int32_t* out_shape_ptr, + const int32_t* out_strides_ptr, const size_t out_shape_size, + const size_t out_id) { + auto& x_info = backend::get_tensor_info(x_id); + auto& out_info = backend::get_tensor_info_out(out_id); + const size_t out_size = out_info.size; + + const float* x_ptr = x_info.f32(); + float* out_buf_ptr = out_info.f32_write(); + + const auto x_strides = + std::vector(x_strides_ptr, x_strides_ptr + x_strides_size); + const auto out_shape = + std::vector(out_shape_ptr, out_shape_ptr + out_shape_size); + const auto out_strides = std::vector( + out_strides_ptr, out_strides_ptr + out_shape_size - 1); + + for (size_t i = 0; i < out_size; ++i) { + auto coords = tfjs::util::offset_to_loc(i, out_strides); + const size_t b = coords[0]; + const size_t h = channels_last ? coords[1] : coords[2]; + const size_t w = channels_last ? coords[2] : coords[3]; + const size_t d = channels_last ? coords[3] : coords[1]; + const size_t out_depth_size = channels_last ? out_shape[3] : out_shape[1]; + + const size_t in_h = h / block_size; + const size_t offset_h = h % block_size; + const size_t in_w = w / block_size; + const size_t offset_w = w % block_size; + const size_t offset_d = (offset_h * block_size + offset_w) * out_depth_size; + + const size_t in_d = d + offset_d; + + size_t x_index = + channels_last + ? tfjs::util::loc_to_offset({b, in_h, in_w, in_d}, x_strides) + : tfjs::util::loc_to_offset({b, in_d, in_h, in_w}, x_strides); + *out_buf_ptr = x_ptr[x_index]; + + out_buf_ptr++; + } +} +} // extern "C" +} // namespace wasm +} // namespace tfjs diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts new file mode 100644 index 00000000000..5c1df286163 --- /dev/null +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -0,0 +1,101 @@ +/** + * @license + * Copyright 2019 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {KernelConfig, KernelFunc, StridedSlice, StridedSliceAttrs, StridedSliceInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; + +import {BackendWasm} from '../backend_wasm'; + +let wasmStridedSlice: ( + xId: number, blockSize: number, channelsLast: number, xStrides: Uint8Array, + xStridesLength: number, outputShape: Uint8Array, outputStrides: Uint8Array, + outSize: number, outId: number) => void; + +function setup(backend: BackendWasm): void { + wasmStridedSlice = backend.wasm.cwrap(StridedSlice, null /*void*/, [ + 'number', // xId + 'number', // blockSize + 'number', // channelsLast + 'array', // xStrides + 'number', // xStridesLength + 'array', // outputShape + 'array', // outputStrides + 'number', // outSize + 'number', // outId + ]); +} + +export function stridedSlice(args: { + backend: BackendWasm, + inputs: StridedSliceInputs, + attrs: StridedSliceAttrs +}): TensorInfo { + const {backend, inputs, attrs} = args; + const {x} = inputs; + const { + begin, + end, + strides, + beginMask, + endMask, + ellipsisMask, + newAxisMask, + shrinkAxisMask + } = attrs; + + util.assert( + blockSize > 1, + () => `blockSize should be > 1 for stridedSlice, but was: ${blockSize}`); + + const batchSize = x.shape[0]; + const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; + const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; + const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; + + const outputHeight = inputHeight * blockSize; + const outputWidth = inputWidth * blockSize; + const outputDepth = inputDepth / (blockSize * blockSize); + + const outputShape = (dataFormat === 'NHWC') ? + [batchSize, outputHeight, outputWidth, outputDepth] : + [batchSize, outputDepth, outputHeight, outputWidth]; + + const out = backend.makeOutput(outputShape, 'float32'); + + const xData = backend.dataIdMap.get(x.dataId); + const xId = xData.id; + const xStridesBytes = + new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer); + + const outputShapeBytes = new Uint8Array(new Int32Array(outputShape).buffer); + const outStridesBytes = + new Uint8Array(new Int32Array(util.computeStrides(outputShape)).buffer); + + const outId = backend.dataIdMap.get(out.dataId).id; + const channelsLast = dataFormat === 'NHWC' ? 1 : 0; + wasmStridedSlice( + xId, blockSize, channelsLast, xStridesBytes, x.shape.length - 1, + outputShapeBytes, outStridesBytes, outputShape.length, outId); + + return out; +} + +export const stridedSliceConfig: KernelConfig = { + kernelName: StridedSlice, + backendName: 'wasm', + setupFunc: setup, + kernelFunc: stridedSlice as {} as KernelFunc +}; diff --git a/tfjs-backend-wasm/src/register_all_kernels.ts b/tfjs-backend-wasm/src/register_all_kernels.ts index 671abc9d53b..7e4d70f6928 100644 --- a/tfjs-backend-wasm/src/register_all_kernels.ts +++ b/tfjs-backend-wasm/src/register_all_kernels.ts @@ -87,6 +87,7 @@ import {splitVConfig} from './kernels/Split'; import {sqrtConfig} from './kernels/Sqrt'; import {squareConfig} from './kernels/Square'; import {squaredDifferenceConfig} from './kernels/SquaredDifference'; +import {stridedSliceConfig} from './kernels/StridedSlice'; import {subConfig} from './kernels/Sub'; import {sumConfig} from './kernels/Sum'; import {tanhConfig} from './kernels/Tanh'; @@ -165,6 +166,7 @@ const kernelConfigs: KernelConfig[] = [ sqrtConfig, squareConfig, squaredDifferenceConfig, + stridedSliceConfig, subConfig, sumConfig, tanhConfig, diff --git a/tfjs-backend-wasm/src/setup_test.ts b/tfjs-backend-wasm/src/setup_test.ts index 247a9a7ec37..496973366cd 100644 --- a/tfjs-backend-wasm/src/setup_test.ts +++ b/tfjs-backend-wasm/src/setup_test.ts @@ -192,6 +192,7 @@ const TEST_FILTERS: TestFilter[] = [ ] }, {include: 'slice '}, + {include: 'stridedSlice '}, {include: 'rotate '}, {include: 'flipLeftRight '}, {include: 'square '}, From 6a10b05980ff54056c3423baf9bbff58b16b7e1c Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 9 Sep 2020 14:27:31 -0400 Subject: [PATCH 02/11] fix --- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 162 +++++++++++++----- tfjs-core/src/backends/backend_util.ts | 3 + 2 files changed, 118 insertions(+), 47 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index 5c1df286163..32c09d51ed4 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -15,9 +15,11 @@ * ============================================================================= */ -import {KernelConfig, KernelFunc, StridedSlice, StridedSliceAttrs, StridedSliceInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, KernelFunc, StridedSlice, StridedSliceAttrs, StridedSliceInputs, TensorInfo} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; +import {reshape} from './Reshape'; +import {slice} from './Slice'; let wasmStridedSlice: ( xId: number, blockSize: number, channelsLast: number, xStrides: Uint8Array, @@ -45,52 +47,118 @@ export function stridedSlice(args: { }): TensorInfo { const {backend, inputs, attrs} = args; const {x} = inputs; - const { - begin, - end, - strides, - beginMask, - endMask, - ellipsisMask, - newAxisMask, - shrinkAxisMask - } = attrs; - - util.assert( - blockSize > 1, - () => `blockSize should be > 1 for stridedSlice, but was: ${blockSize}`); - - const batchSize = x.shape[0]; - const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; - const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; - const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; - - const outputHeight = inputHeight * blockSize; - const outputWidth = inputWidth * blockSize; - const outputDepth = inputDepth / (blockSize * blockSize); - - const outputShape = (dataFormat === 'NHWC') ? - [batchSize, outputHeight, outputWidth, outputDepth] : - [batchSize, outputDepth, outputHeight, outputWidth]; - - const out = backend.makeOutput(outputShape, 'float32'); - - const xData = backend.dataIdMap.get(x.dataId); - const xId = xData.id; - const xStridesBytes = - new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer); - - const outputShapeBytes = new Uint8Array(new Int32Array(outputShape).buffer); - const outStridesBytes = - new Uint8Array(new Int32Array(util.computeStrides(outputShape)).buffer); - - const outId = backend.dataIdMap.get(out.dataId).id; - const channelsLast = dataFormat === 'NHWC' ? 1 : 0; - wasmStridedSlice( - xId, blockSize, channelsLast, xStridesBytes, x.shape.length - 1, - outputShapeBytes, outStridesBytes, outputShape.length, outId); - - return out; + + let {begin, end, strides} = attrs; + const {beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask} = attrs; + + const ellipsisAxes = backend_util.slice_util.maskToAxes(ellipsisMask); + if (ellipsisAxes.length > 1) { + throw new Error('Multiple ellipses in slice is not allowed.'); + } + + if (ellipsisMask !== 0 && newAxisMask !== 0) { + throw new Error( + 'Using both ellipsisMask and newAxisMask is not yet supported.'); + } + + if (ellipsisMask !== 0 && shrinkAxisMask !== 0) { + throw new Error( + 'Using both ellipsisMask and shrinkAxisMask is not yet supported.'); + } + + const numInterpolatedAxes = x.shape.length - begin.length; + + // Expand the dims of x based on the newAxisMask. + const expandAxes = backend_util.slice_util.maskToAxes(newAxisMask); + const newShape = x.shape.slice(); + expandAxes.forEach(axis => { + begin[axis] = 0; + end[axis] = 1; + newShape.splice(axis, 0, 1); + }); + + const xReshaped = reshape({inputs: {x}, attrs: {shape: newShape}, backend}); + + // Normalize the start, end and strides. + if (ellipsisAxes.length && numInterpolatedAxes > 0) { + const fullIndex = ellipsisAxes[0]; + + // The ellipsis applies to the masked index as well as any dimensions + // that are interpolated. + const numElidedAxes = numInterpolatedAxes + 1; + begin = backend_util.slice_util.startIndicesWithElidedDims( + beginMask, fullIndex, numElidedAxes, begin, xReshaped.shape); + end = backend_util.slice_util.stopIndicesWithElidedDims( + endMask, fullIndex, numElidedAxes, end, xReshaped.shape); + strides = backend_util.slice_util.stridesWithElidedDims( + strides, fullIndex, numElidedAxes, xReshaped.shape); + } else { + for (let axis = 0; axis < xReshaped.shape.length; axis++) { + begin[axis] = backend_util.slice_util.startForAxis( + beginMask, begin, strides, xReshaped.shape, axis, ellipsisMask); + end[axis] = backend_util.slice_util.stopForAxis( + endMask, end, strides, xReshaped.shape, axis, ellipsisMask); + strides[axis] = + backend_util.slice_util.stridesForAxis(strides, axis, ellipsisMask); + } + } + + const shrinkAxes = backend_util.slice_util.maskToAxes(shrinkAxisMask); + // Adjust the ends based on the shrink mask. + shrinkAxes.forEach(axis => { + end[axis] = begin[axis] + 1; + strides[axis] = 1; + }); + + // Figure out the output shape. + const size = backend_util.slice_util.computeOutShape(begin, end, strides); + // Remove the axes based on shrinkMask. + const outShape = size.filter((_, axis) => shrinkAxes.indexOf(axis) === -1); + + const nonStrided = strides.every(v => v === 1); + if (nonStrided) { + const xSliced = slice({inputs: {x}, attrs: {begin, size}, backend}); + return reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend}); + } + + const out = backend.makeOutput(outShape, 'float32'); + + wasmStridedSlice(xReshaped, begin, end, strides); + + return reshape({inputs: {x: out}, attrs: {shape: outShape}, backend}); + + // const batchSize = x.shape[0]; + // const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; + // const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; + // const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; + + // const outputHeight = inputHeight * blockSize; + // const outputWidth = inputWidth * blockSize; + // const outputDepth = inputDepth / (blockSize * blockSize); + + // const outputShape = (dataFormat === 'NHWC') ? + // [batchSize, outputHeight, outputWidth, outputDepth] : + // [batchSize, outputDepth, outputHeight, outputWidth]; + + // const out = backend.makeOutput(outputShape, 'float32'); + + // const xData = backend.dataIdMap.get(x.dataId); + // const xId = xData.id; + // const xStridesBytes = + // new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer); + + // const outputShapeBytes = new Uint8Array(new + // Int32Array(outputShape).buffer); const outStridesBytes = + // new Uint8Array(new + // Int32Array(util.computeStrides(outputShape)).buffer); + + // const outId = backend.dataIdMap.get(out.dataId).id; + // const channelsLast = dataFormat === 'NHWC' ? 1 : 0; + // wasmStridedSlice( + // xId, blockSize, channelsLast, xStridesBytes, x.shape.length - 1, + // outputShapeBytes, outStridesBytes, outputShape.length, outId); + + // return out; } export const stridedSliceConfig: KernelConfig = { diff --git a/tfjs-core/src/backends/backend_util.ts b/tfjs-core/src/backends/backend_util.ts index a5d8899353f..c18d8cac38f 100644 --- a/tfjs-core/src/backends/backend_util.ts +++ b/tfjs-core/src/backends/backend_util.ts @@ -36,6 +36,9 @@ export * from '../ops/fused_util'; export * from '../ops/fused_types'; export * from '../ops/reduce_util'; +import * as slice_util from '../ops/slice_util'; +export {slice_util}; + export {BackendValues, TypedArray, upcastType, PixelData} from '../types'; export {MemoryInfo, TimingInfo} from '../engine'; export * from '../ops/rotate_util'; From 62636daf837e4f363f90f05c6f81ad0798b03cf4 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 9 Sep 2020 16:38:21 -0400 Subject: [PATCH 03/11] fix --- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 38 +++++++++++++------ 1 file changed, 27 insertions(+), 11 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index 32c09d51ed4..4ca816cf4de 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -22,20 +22,26 @@ import {reshape} from './Reshape'; import {slice} from './Slice'; let wasmStridedSlice: ( - xId: number, blockSize: number, channelsLast: number, xStrides: Uint8Array, - xStridesLength: number, outputShape: Uint8Array, outputStrides: Uint8Array, - outSize: number, outId: number) => void; + xId: number, beginBytes: Uint8Array, beginLength: number, + endBytes: Uint8Array, endLength: number, stridesBytes: Uint8Array, + stridesLength: number, beginMask: number, endMask: number, + ellipsisMask: number, newAxisMask: number, shrinkAxisMask: number, + outId: number) => void; function setup(backend: BackendWasm): void { wasmStridedSlice = backend.wasm.cwrap(StridedSlice, null /*void*/, [ 'number', // xId - 'number', // blockSize - 'number', // channelsLast - 'array', // xStrides - 'number', // xStridesLength - 'array', // outputShape - 'array', // outputStrides - 'number', // outSize + 'array', // beginBytes + 'number', // beginLength + 'array', // endBytes + 'number', // endLength + 'array', // stridesBytes + 'number', // stridesLength + 'number', // beginMask + 'number', // endMask + 'number', // ellipsisMask + 'number', // newAxisMask + 'number', // shrinkAxisMask 'number', // outId ]); } @@ -121,9 +127,19 @@ export function stridedSlice(args: { return reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend}); } + const xId = backend.dataIdMap.get(xReshaped.dataId); + const out = backend.makeOutput(outShape, 'float32'); + const outId = backend.dataIdMap.get(out.dataId).id; + + const beginBytes = new Uint8Array(new Int32Array(begin).buffer); + const endBytes = new Uint8Array(new Int32Array(end).buffer); + const stridesBytes = new Uint8Array(new Int32Array(strides).buffer); - wasmStridedSlice(xReshaped, begin, end, strides); + wasmStridedSlice( + xId, beginBytes, begin.length, endBytes, end.length, stridesBytes, + strides.length, beginMask, endMask, ellipsisMask, newAxisMask, + shrinkAxisMask, outId); return reshape({inputs: {x: out}, attrs: {shape: outShape}, backend}); From 4ff38b264a67154a5e37c8407067ff40a6b1390d Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 9 Sep 2020 16:42:01 -0400 Subject: [PATCH 04/11] fix interface --- tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc index 272a7344f30..ab2a6fdfe0e 100644 --- a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc +++ b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc @@ -26,10 +26,12 @@ extern "C" { EMSCRIPTEN_KEEPALIVE #endif -void StridedSlice(const size_t x_id, const size_t block_size, - const bool channels_last, const int32_t* x_strides_ptr, - const size_t x_strides_size, const int32_t* out_shape_ptr, - const int32_t* out_strides_ptr, const size_t out_shape_size, +void StridedSlice(const size_t x_id, const int32_t* begin_ptr, + const size_t begin_size, const int32_t* end_ptr, + const size_t end_size, const int32_t* strides_ptr, + const size_t strides_size, const size_t begin_mask, + const size_t end_mask, const size_t ellipsis_mask, + const size_t new_axis_mask, const size_t shrink_axis_mask, const size_t out_id) { auto& x_info = backend::get_tensor_info(x_id); auto& out_info = backend::get_tensor_info_out(out_id); From cf258df971aebf4fd22c7b9dd3766f5043db31f8 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Wed, 9 Sep 2020 16:45:41 -0400 Subject: [PATCH 05/11] fix --- .../src/cc/kernels/StridedSlice.cc | 47 ++++++++++--------- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 35 +------------- 2 files changed, 25 insertions(+), 57 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc index ab2a6fdfe0e..f27100d80e3 100644 --- a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc +++ b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc @@ -40,34 +40,35 @@ void StridedSlice(const size_t x_id, const int32_t* begin_ptr, const float* x_ptr = x_info.f32(); float* out_buf_ptr = out_info.f32_write(); - const auto x_strides = - std::vector(x_strides_ptr, x_strides_ptr + x_strides_size); - const auto out_shape = - std::vector(out_shape_ptr, out_shape_ptr + out_shape_size); - const auto out_strides = std::vector( - out_strides_ptr, out_strides_ptr + out_shape_size - 1); + const auto begin = std::vector(begin_ptr, begin_ptr + begin_size); + const auto end = std::vector(end_ptr, end_ptr + end_size); + const auto strides = + std::vector(strides_ptr, strides_ptr + strides_size); for (size_t i = 0; i < out_size; ++i) { - auto coords = tfjs::util::offset_to_loc(i, out_strides); - const size_t b = coords[0]; - const size_t h = channels_last ? coords[1] : coords[2]; - const size_t w = channels_last ? coords[2] : coords[3]; - const size_t d = channels_last ? coords[3] : coords[1]; - const size_t out_depth_size = channels_last ? out_shape[3] : out_shape[1]; + // auto coords = tfjs::util::offset_to_loc(i, out_strides); + // const size_t b = coords[0]; + // const size_t h = channels_last ? coords[1] : coords[2]; + // const size_t w = channels_last ? coords[2] : coords[3]; + // const size_t d = channels_last ? coords[3] : coords[1]; + // const size_t out_depth_size = channels_last ? out_shape[3] : + // out_shape[1]; - const size_t in_h = h / block_size; - const size_t offset_h = h % block_size; - const size_t in_w = w / block_size; - const size_t offset_w = w % block_size; - const size_t offset_d = (offset_h * block_size + offset_w) * out_depth_size; + // const size_t in_h = h / block_size; + // const size_t offset_h = h % block_size; + // const size_t in_w = w / block_size; + // const size_t offset_w = w % block_size; + // const size_t offset_d = (offset_h * block_size + offset_w) * + // out_depth_size; - const size_t in_d = d + offset_d; + // const size_t in_d = d + offset_d; - size_t x_index = - channels_last - ? tfjs::util::loc_to_offset({b, in_h, in_w, in_d}, x_strides) - : tfjs::util::loc_to_offset({b, in_d, in_h, in_w}, x_strides); - *out_buf_ptr = x_ptr[x_index]; + // size_t x_index = + // channels_last + // ? tfjs::util::loc_to_offset({b, in_h, in_w, in_d}, x_strides) + // : tfjs::util::loc_to_offset({b, in_d, in_h, in_w}, x_strides); + + // *out_buf_ptr = x_ptr[x_index]; out_buf_ptr++; } diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index 4ca816cf4de..119cd0c8114 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -127,7 +127,7 @@ export function stridedSlice(args: { return reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend}); } - const xId = backend.dataIdMap.get(xReshaped.dataId); + const xId = backend.dataIdMap.get(xReshaped.dataId).id; const out = backend.makeOutput(outShape, 'float32'); const outId = backend.dataIdMap.get(out.dataId).id; @@ -142,39 +142,6 @@ export function stridedSlice(args: { shrinkAxisMask, outId); return reshape({inputs: {x: out}, attrs: {shape: outShape}, backend}); - - // const batchSize = x.shape[0]; - // const inputHeight = (dataFormat === 'NHWC') ? x.shape[1] : x.shape[2]; - // const inputWidth = (dataFormat === 'NHWC') ? x.shape[2] : x.shape[3]; - // const inputDepth = (dataFormat === 'NHWC') ? x.shape[3] : x.shape[1]; - - // const outputHeight = inputHeight * blockSize; - // const outputWidth = inputWidth * blockSize; - // const outputDepth = inputDepth / (blockSize * blockSize); - - // const outputShape = (dataFormat === 'NHWC') ? - // [batchSize, outputHeight, outputWidth, outputDepth] : - // [batchSize, outputDepth, outputHeight, outputWidth]; - - // const out = backend.makeOutput(outputShape, 'float32'); - - // const xData = backend.dataIdMap.get(x.dataId); - // const xId = xData.id; - // const xStridesBytes = - // new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer); - - // const outputShapeBytes = new Uint8Array(new - // Int32Array(outputShape).buffer); const outStridesBytes = - // new Uint8Array(new - // Int32Array(util.computeStrides(outputShape)).buffer); - - // const outId = backend.dataIdMap.get(out.dataId).id; - // const channelsLast = dataFormat === 'NHWC' ? 1 : 0; - // wasmStridedSlice( - // xId, blockSize, channelsLast, xStridesBytes, x.shape.length - 1, - // outputShapeBytes, outStridesBytes, outputShape.length, outId); - - // return out; } export const stridedSliceConfig: KernelConfig = { From f42ad8fa123ae97376d988ddfdd8f912bd998584 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 10 Sep 2020 09:47:00 -0400 Subject: [PATCH 06/11] basic --- .../src/cc/kernels/StridedSlice.cc | 42 ++++++++--------- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 45 +++++++++++++------ 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc index f27100d80e3..3f91dd0b3f7 100644 --- a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc +++ b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc @@ -26,13 +26,15 @@ extern "C" { EMSCRIPTEN_KEEPALIVE #endif -void StridedSlice(const size_t x_id, const int32_t* begin_ptr, +void StridedSlice(const size_t x_id, const int32_t* x_strides_ptr, + const size_t x_strides_size, const int32_t* begin_ptr, const size_t begin_size, const int32_t* end_ptr, const size_t end_size, const int32_t* strides_ptr, const size_t strides_size, const size_t begin_mask, const size_t end_mask, const size_t ellipsis_mask, const size_t new_axis_mask, const size_t shrink_axis_mask, - const size_t out_id) { + const int32_t* out_shape_ptr, const int32_t* out_strides_ptr, + const size_t out_shape_size, const size_t out_id) { auto& x_info = backend::get_tensor_info(x_id); auto& out_info = backend::get_tensor_info_out(out_id); const size_t out_size = out_info.size; @@ -40,35 +42,29 @@ void StridedSlice(const size_t x_id, const int32_t* begin_ptr, const float* x_ptr = x_info.f32(); float* out_buf_ptr = out_info.f32_write(); + const auto x_strides = + std::vector(x_strides_ptr, x_strides_ptr + x_strides_size); + const auto begin = std::vector(begin_ptr, begin_ptr + begin_size); const auto end = std::vector(end_ptr, end_ptr + end_size); const auto strides = std::vector(strides_ptr, strides_ptr + strides_size); - for (size_t i = 0; i < out_size; ++i) { - // auto coords = tfjs::util::offset_to_loc(i, out_strides); - // const size_t b = coords[0]; - // const size_t h = channels_last ? coords[1] : coords[2]; - // const size_t w = channels_last ? coords[2] : coords[3]; - // const size_t d = channels_last ? coords[3] : coords[1]; - // const size_t out_depth_size = channels_last ? out_shape[3] : - // out_shape[1]; - - // const size_t in_h = h / block_size; - // const size_t offset_h = h % block_size; - // const size_t in_w = w / block_size; - // const size_t offset_w = w % block_size; - // const size_t offset_d = (offset_h * block_size + offset_w) * - // out_depth_size; + const auto out_shape = + std::vector(out_shape_ptr, out_shape_ptr + out_shape_size); + const auto out_strides = std::vector( + out_strides_ptr, out_strides_ptr + out_shape_size - 1); - // const size_t in_d = d + offset_d; + for (size_t i = 0; i < out_size; ++i) { + auto coords = tfjs::util::offset_to_loc(i, out_strides); - // size_t x_index = - // channels_last - // ? tfjs::util::loc_to_offset({b, in_h, in_w, in_d}, x_strides) - // : tfjs::util::loc_to_offset({b, in_d, in_h, in_w}, x_strides); + std::vector new_loc = {}; + for (size_t j = 0; j < out_shape_size; ++j) { + new_loc.push_back(coords[j] * strides[j] + begin[j]); + } - // *out_buf_ptr = x_ptr[x_index]; + const size_t x_index = tfjs::util::loc_to_offset(new_loc, x_strides); + *out_buf_ptr = x_ptr[x_index]; out_buf_ptr++; } diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index 119cd0c8114..3975a0c7275 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -15,22 +15,25 @@ * ============================================================================= */ -import {backend_util, KernelConfig, KernelFunc, StridedSlice, StridedSliceAttrs, StridedSliceInputs, TensorInfo} from '@tensorflow/tfjs-core'; +import {backend_util, KernelConfig, KernelFunc, StridedSlice, StridedSliceAttrs, StridedSliceInputs, TensorInfo, util} from '@tensorflow/tfjs-core'; import {BackendWasm} from '../backend_wasm'; import {reshape} from './Reshape'; import {slice} from './Slice'; let wasmStridedSlice: ( - xId: number, beginBytes: Uint8Array, beginLength: number, - endBytes: Uint8Array, endLength: number, stridesBytes: Uint8Array, - stridesLength: number, beginMask: number, endMask: number, - ellipsisMask: number, newAxisMask: number, shrinkAxisMask: number, - outId: number) => void; + xId: number, xStridesBytes: Uint8Array, xStridesLength: number, + beginBytes: Uint8Array, beginLength: number, endBytes: Uint8Array, + endLength: number, stridesBytes: Uint8Array, stridesLength: number, + beginMask: number, endMask: number, ellipsisMask: number, + newAxisMask: number, shrinkAxisMask: number, outShapeBytes: Uint8Array, + outStridesBytes: Uint8Array, outShapeLength: number, outId: number) => void; function setup(backend: BackendWasm): void { wasmStridedSlice = backend.wasm.cwrap(StridedSlice, null /*void*/, [ 'number', // xId + 'array', // xStrides + 'number', // xStridesLength 'array', // beginBytes 'number', // beginLength 'array', // endBytes @@ -42,6 +45,9 @@ function setup(backend: BackendWasm): void { 'number', // ellipsisMask 'number', // newAxisMask 'number', // shrinkAxisMask + 'array', // outShapeBytes + 'array', // outStridesBytes + 'number', // outShapeLength 'number', // outId ]); } @@ -130,16 +136,27 @@ export function stridedSlice(args: { const xId = backend.dataIdMap.get(xReshaped.dataId).id; const out = backend.makeOutput(outShape, 'float32'); - const outId = backend.dataIdMap.get(out.dataId).id; - const beginBytes = new Uint8Array(new Int32Array(begin).buffer); - const endBytes = new Uint8Array(new Int32Array(end).buffer); - const stridesBytes = new Uint8Array(new Int32Array(strides).buffer); + if (!outShape.some(axis => axis === 0)) { + const outId = backend.dataIdMap.get(out.dataId).id; - wasmStridedSlice( - xId, beginBytes, begin.length, endBytes, end.length, stridesBytes, - strides.length, beginMask, endMask, ellipsisMask, newAxisMask, - shrinkAxisMask, outId); + const beginBytes = new Uint8Array(new Int32Array(begin).buffer); + const endBytes = new Uint8Array(new Int32Array(end).buffer); + const stridesBytes = new Uint8Array(new Int32Array(strides).buffer); + + const xStridesBytes = + new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer); + + const outputShapeBytes = new Uint8Array(new Int32Array(outShape).buffer); + const outStridesBytes = + new Uint8Array(new Int32Array(util.computeStrides(outShape)).buffer); + + wasmStridedSlice( + xId, xStridesBytes, x.shape.length - 1, beginBytes, begin.length, + endBytes, end.length, stridesBytes, strides.length, beginMask, endMask, + ellipsisMask, newAxisMask, shrinkAxisMask, outputShapeBytes, + outStridesBytes, outShape.length, outId); + } return reshape({inputs: {x: out}, attrs: {shape: outShape}, backend}); } From 542a5e339d9b6fe4745ffc37a4c216910fade036 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 10 Sep 2020 10:09:08 -0400 Subject: [PATCH 07/11] clean --- .../src/cc/kernels/StridedSlice.cc | 17 +++++------- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 27 ++++++------------- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc index 3f91dd0b3f7..76daf7ad617 100644 --- a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc +++ b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc @@ -27,12 +27,8 @@ EMSCRIPTEN_KEEPALIVE #endif void StridedSlice(const size_t x_id, const int32_t* x_strides_ptr, - const size_t x_strides_size, const int32_t* begin_ptr, - const size_t begin_size, const int32_t* end_ptr, - const size_t end_size, const int32_t* strides_ptr, - const size_t strides_size, const size_t begin_mask, - const size_t end_mask, const size_t ellipsis_mask, - const size_t new_axis_mask, const size_t shrink_axis_mask, + const size_t x_rank, const int32_t* begin_ptr, + const int32_t* end_ptr, const int32_t* strides_ptr, const int32_t* out_shape_ptr, const int32_t* out_strides_ptr, const size_t out_shape_size, const size_t out_id) { auto& x_info = backend::get_tensor_info(x_id); @@ -43,12 +39,11 @@ void StridedSlice(const size_t x_id, const int32_t* x_strides_ptr, float* out_buf_ptr = out_info.f32_write(); const auto x_strides = - std::vector(x_strides_ptr, x_strides_ptr + x_strides_size); + std::vector(x_strides_ptr, x_strides_ptr + x_rank - 1); - const auto begin = std::vector(begin_ptr, begin_ptr + begin_size); - const auto end = std::vector(end_ptr, end_ptr + end_size); - const auto strides = - std::vector(strides_ptr, strides_ptr + strides_size); + const auto begin = std::vector(begin_ptr, begin_ptr + x_rank); + const auto end = std::vector(end_ptr, end_ptr + x_rank); + const auto strides = std::vector(strides_ptr, strides_ptr + x_rank); const auto out_shape = std::vector(out_shape_ptr, out_shape_ptr + out_shape_size); diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index 3975a0c7275..ae08b09875c 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -22,29 +22,19 @@ import {reshape} from './Reshape'; import {slice} from './Slice'; let wasmStridedSlice: ( - xId: number, xStridesBytes: Uint8Array, xStridesLength: number, - beginBytes: Uint8Array, beginLength: number, endBytes: Uint8Array, - endLength: number, stridesBytes: Uint8Array, stridesLength: number, - beginMask: number, endMask: number, ellipsisMask: number, - newAxisMask: number, shrinkAxisMask: number, outShapeBytes: Uint8Array, - outStridesBytes: Uint8Array, outShapeLength: number, outId: number) => void; + xId: number, xStridesBytes: Uint8Array, xRank: number, + beginBytes: Uint8Array, endBytes: Uint8Array, stridesBytes: Uint8Array, + outShapeBytes: Uint8Array, outStridesBytes: Uint8Array, + outShapeLength: number, outId: number) => void; function setup(backend: BackendWasm): void { wasmStridedSlice = backend.wasm.cwrap(StridedSlice, null /*void*/, [ 'number', // xId 'array', // xStrides - 'number', // xStridesLength + 'number', // xRank 'array', // beginBytes - 'number', // beginLength 'array', // endBytes - 'number', // endLength 'array', // stridesBytes - 'number', // stridesLength - 'number', // beginMask - 'number', // endMask - 'number', // ellipsisMask - 'number', // newAxisMask - 'number', // shrinkAxisMask 'array', // outShapeBytes 'array', // outStridesBytes 'number', // outShapeLength @@ -152,10 +142,9 @@ export function stridedSlice(args: { new Uint8Array(new Int32Array(util.computeStrides(outShape)).buffer); wasmStridedSlice( - xId, xStridesBytes, x.shape.length - 1, beginBytes, begin.length, - endBytes, end.length, stridesBytes, strides.length, beginMask, endMask, - ellipsisMask, newAxisMask, shrinkAxisMask, outputShapeBytes, - outStridesBytes, outShape.length, outId); + xId, xStridesBytes, xReshaped.shape.length, beginBytes, endBytes, + stridesBytes, outputShapeBytes, outStridesBytes, outShape.length, + outId); } return reshape({inputs: {x: out}, attrs: {shape: outShape}, backend}); From 930d19de5ec24de316dedaf5347a5c591a24c214 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 10 Sep 2020 10:22:12 -0400 Subject: [PATCH 08/11] clean --- tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc | 2 -- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 12 ++++-------- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc index 76daf7ad617..0e1dc252a0a 100644 --- a/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc +++ b/tfjs-backend-wasm/src/cc/kernels/StridedSlice.cc @@ -40,11 +40,9 @@ void StridedSlice(const size_t x_id, const int32_t* x_strides_ptr, const auto x_strides = std::vector(x_strides_ptr, x_strides_ptr + x_rank - 1); - const auto begin = std::vector(begin_ptr, begin_ptr + x_rank); const auto end = std::vector(end_ptr, end_ptr + x_rank); const auto strides = std::vector(strides_ptr, strides_ptr + x_rank); - const auto out_shape = std::vector(out_shape_ptr, out_shape_ptr + out_shape_size); const auto out_strides = std::vector( diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index ae08b09875c..da7e0251a1e 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -123,23 +123,19 @@ export function stridedSlice(args: { return reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend}); } - const xId = backend.dataIdMap.get(xReshaped.dataId).id; - const out = backend.makeOutput(outShape, 'float32'); - if (!outShape.some(axis => axis === 0)) { - const outId = backend.dataIdMap.get(out.dataId).id; - + const xId = backend.dataIdMap.get(xReshaped.dataId).id; + const xStridesBytes = new Uint8Array( + new Int32Array(util.computeStrides(xReshaped.shape)).buffer); const beginBytes = new Uint8Array(new Int32Array(begin).buffer); const endBytes = new Uint8Array(new Int32Array(end).buffer); const stridesBytes = new Uint8Array(new Int32Array(strides).buffer); - const xStridesBytes = - new Uint8Array(new Int32Array(util.computeStrides(x.shape)).buffer); - const outputShapeBytes = new Uint8Array(new Int32Array(outShape).buffer); const outStridesBytes = new Uint8Array(new Int32Array(util.computeStrides(outShape)).buffer); + const outId = backend.dataIdMap.get(out.dataId).id; wasmStridedSlice( xId, xStridesBytes, xReshaped.shape.length, beginBytes, endBytes, From c08b20e6535d96d9709e9b651e468bc3ece8cc63 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 10 Sep 2020 10:41:27 -0400 Subject: [PATCH 09/11] clean --- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 34 ++++++----------- tfjs-core/src/ops/slice_util.ts | 38 +++++++++++++++++++ tfjs-core/src/ops/strided_slice.ts | 36 +++++++----------- 3 files changed, 62 insertions(+), 46 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index da7e0251a1e..c6c620c9aa1 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -81,29 +81,17 @@ export function stridedSlice(args: { const xReshaped = reshape({inputs: {x}, attrs: {shape: newShape}, backend}); - // Normalize the start, end and strides. - if (ellipsisAxes.length && numInterpolatedAxes > 0) { - const fullIndex = ellipsisAxes[0]; - - // The ellipsis applies to the masked index as well as any dimensions - // that are interpolated. - const numElidedAxes = numInterpolatedAxes + 1; - begin = backend_util.slice_util.startIndicesWithElidedDims( - beginMask, fullIndex, numElidedAxes, begin, xReshaped.shape); - end = backend_util.slice_util.stopIndicesWithElidedDims( - endMask, fullIndex, numElidedAxes, end, xReshaped.shape); - strides = backend_util.slice_util.stridesWithElidedDims( - strides, fullIndex, numElidedAxes, xReshaped.shape); - } else { - for (let axis = 0; axis < xReshaped.shape.length; axis++) { - begin[axis] = backend_util.slice_util.startForAxis( - beginMask, begin, strides, xReshaped.shape, axis, ellipsisMask); - end[axis] = backend_util.slice_util.stopForAxis( - endMask, end, strides, xReshaped.shape, axis, ellipsisMask); - strides[axis] = - backend_util.slice_util.stridesForAxis(strides, axis, ellipsisMask); - } - } + const { + begin: normalizedBegin, + end: normalizedEnd, + strides: normalizedStrides + } = + backend_util.slice_util.getNormalizedAxes( + xReshaped.shape, ellipsisAxes, numInterpolatedAxes, begin, end, + strides, beginMask, endMask, ellipsisMask); + begin = normalizedBegin; + end = normalizedEnd; + strides = normalizedStrides; const shrinkAxes = backend_util.slice_util.maskToAxes(shrinkAxisMask); // Adjust the ends based on the shrink mask. diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 10799d2a57b..8e608d585c1 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -102,6 +102,44 @@ function getElidedAxes(numElidedAxes: number, ellipsisInsertionIndex: number) { return elidedAxes; } +// Normalize the start, end and strides. +export function getNormalizedAxes( + inputShape: number[], ellipsisAxes: number[], numInterpolatedAxes: number, + begin: number[], end: number[], strides: number[], beginMask: number, + endMask: number, ellipsisMask: number) { + const inputRank = inputShape.length; + let normalizedBegin = new Array(inputRank), + normalizedEnd = new Array(inputRank), + normalizedStrides = new Array(inputRank); + if (ellipsisAxes.length && numInterpolatedAxes > 0) { + const fullIndex = ellipsisAxes[0]; + + // The ellipsis applies to the masked index as well as any dimensions + // that are interpolated. + const numElidedAxes = numInterpolatedAxes + 1; + normalizedBegin = startIndicesWithElidedDims( + beginMask, fullIndex, numElidedAxes, begin, inputShape); + normalizedEnd = stopIndicesWithElidedDims( + endMask, fullIndex, numElidedAxes, end, inputShape); + normalizedStrides = + stridesWithElidedDims(strides, fullIndex, numElidedAxes, inputShape); + } else { + for (let axis = 0; axis < inputRank; axis++) { + normalizedBegin[axis] = startForAxis( + beginMask, begin, strides, inputShape, axis, ellipsisMask); + normalizedEnd[axis] = + stopForAxis(endMask, end, strides, inputShape, axis, ellipsisMask); + normalizedStrides[axis] = stridesForAxis(strides, axis, ellipsisMask); + } + } + + return { + begin: normalizedBegin, + end: normalizedEnd, + strides: normalizedStrides + }; +} + // Creates full selection at the elided dimensions. If the dimension matches // the ellipsis mask, override the current start value. Otherwise, insert. export function startIndicesWithElidedDims( diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index 9fb70b98176..581dee4646b 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -26,7 +26,8 @@ import {TensorLike} from '../types'; import {op} from './operation'; import {reshape} from './reshape'; import {slice} from './slice'; -import {computeOutShape, maskToAxes, startForAxis, startIndicesWithElidedDims, stopForAxis, stopIndicesWithElidedDims, stridesForAxis, stridesWithElidedDims} from './slice_util'; +import {computeOutShape, getNormalizedAxes, maskToAxes} from './slice_util'; + /** * Extracts a strided slice of a tensor. @@ -98,28 +99,17 @@ function stridedSlice_( }); $x = reshape($x, newShape); - // Normalize the start, end and strides. - if (ellipsisAxes.length && numInterpolatedAxes > 0) { - const fullIndex = ellipsisAxes[0]; - - // The ellipsis applies to the masked index as well as any dimensions - // that are interpolated. - const numElidedAxes = numInterpolatedAxes + 1; - begin = startIndicesWithElidedDims( - beginMask, fullIndex, numElidedAxes, begin, $x.shape); - end = stopIndicesWithElidedDims( - endMask, fullIndex, numElidedAxes, end, $x.shape); - strides = - stridesWithElidedDims(strides, fullIndex, numElidedAxes, $x.shape); - } else { - for (let axis = 0; axis < $x.rank; axis++) { - begin[axis] = startForAxis( - beginMask, begin, strides, $x.shape, axis, ellipsisMask); - end[axis] = - stopForAxis(endMask, end, strides, $x.shape, axis, ellipsisMask); - strides[axis] = stridesForAxis(strides, axis, ellipsisMask); - } - } + const { + begin: normalizedBegin, + end: normalizedEnd, + strides: normalizedStrides + } = + getNormalizedAxes( + $x.shape, ellipsisAxes, numInterpolatedAxes, begin, end, strides, + beginMask, endMask, ellipsisMask); + begin = normalizedBegin; + end = normalizedEnd; + strides = normalizedStrides; const shrinkAxes = maskToAxes(shrinkAxisMask); // Adjust the ends based on the shrink mask. From c15eca5e631324567c3bd4c471d5bb0944a1bc68 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 10 Sep 2020 10:47:41 -0400 Subject: [PATCH 10/11] clean --- tfjs-core/src/ops/slice_util.ts | 3 ++- tfjs-core/src/ops/strided_slice.ts | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 8e608d585c1..bd236038c96 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -106,7 +106,8 @@ function getElidedAxes(numElidedAxes: number, ellipsisInsertionIndex: number) { export function getNormalizedAxes( inputShape: number[], ellipsisAxes: number[], numInterpolatedAxes: number, begin: number[], end: number[], strides: number[], beginMask: number, - endMask: number, ellipsisMask: number) { + endMask: number, + ellipsisMask: number): {begin: number[], end: number[], strides: number[]} { const inputRank = inputShape.length; let normalizedBegin = new Array(inputRank), normalizedEnd = new Array(inputRank), diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index 581dee4646b..4b94cccb5e1 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -28,7 +28,6 @@ import {reshape} from './reshape'; import {slice} from './slice'; import {computeOutShape, getNormalizedAxes, maskToAxes} from './slice_util'; - /** * Extracts a strided slice of a tensor. * From 97a5308e3aa0c1557b4808f464d50e7a7b26b3e5 Mon Sep 17 00:00:00 2001 From: Ann Yuan Date: Thu, 10 Sep 2020 16:36:42 -0400 Subject: [PATCH 11/11] mv --- tfjs-backend-wasm/src/kernels/StridedSlice.ts | 6 +++++- tfjs-core/src/ops/strided_slice.ts | 8 ++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/tfjs-backend-wasm/src/kernels/StridedSlice.ts b/tfjs-backend-wasm/src/kernels/StridedSlice.ts index c6c620c9aa1..bc55b8fa827 100644 --- a/tfjs-backend-wasm/src/kernels/StridedSlice.ts +++ b/tfjs-backend-wasm/src/kernels/StridedSlice.ts @@ -1,6 +1,6 @@ /** * @license - * Copyright 2019 Google LLC. All Rights Reserved. + * Copyright 2020 Google LLC. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -51,6 +51,10 @@ export function stridedSlice(args: { const {x} = inputs; let {begin, end, strides} = attrs; + if (strides == null) { + strides = new Array(begin.length); + } + const {beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask} = attrs; const ellipsisAxes = backend_util.slice_util.maskToAxes(ellipsisMask); diff --git a/tfjs-core/src/ops/strided_slice.ts b/tfjs-core/src/ops/strided_slice.ts index bb6b642cfae..d99020c16ff 100644 --- a/tfjs-core/src/ops/strided_slice.ts +++ b/tfjs-core/src/ops/strided_slice.ts @@ -65,13 +65,13 @@ function stridedSlice_( x: Tensor|TensorLike, begin: number[], end: number[], strides?: number[], beginMask = 0, endMask = 0, ellipsisMask = 0, newAxisMask = 0, shrinkAxisMask = 0): Tensor { - if (strides == null) { - strides = new Array(begin.length); - } - let $x = convertToTensor(x, 'x', 'stridedSlice'); const forward: ForwardFunc = (backend) => { + if (strides == null) { + strides = new Array(begin.length); + } + const ellipsisAxes = maskToAxes(ellipsisMask); if (ellipsisAxes.length > 1) { throw new Error('Multiple ellipses in slice is not allowed.');