Skip to content

Commit

Permalink
fix wasm strided slice with new axis (#4235)
Browse files Browse the repository at this point in the history
BUG
BUG

Co-authored-by: Ping Yu <4018+pyu10055@users.noreply.github.com>
  • Loading branch information
pvanhaes and pyu10055 committed Mar 2, 2021
1 parent 93e00bc commit bf1a062
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tfjs-backend-wasm/src/kernels/StridedSlice.ts
Expand Up @@ -111,7 +111,8 @@ export function stridedSlice(args: {

const nonStrided = strides.every(v => v === 1);
if (nonStrided) {
const xSliced = slice({inputs: {x}, attrs: {begin, size}, backend});
const xSliced = slice(
{inputs: {x: xReshaped}, attrs: {begin, size}, backend});
backend.disposeData(xReshaped.dataId);
const reshaped =
reshape({inputs: {x: xSliced}, attrs: {shape: outShape}, backend});
Expand Down
3 changes: 2 additions & 1 deletion tfjs-core/src/ops/strided_slice_test.ts
Expand Up @@ -134,7 +134,7 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => {
expectArraysClose(await output.data(), [0, 2]);
});

it('strided slice with several new axes', () => {
it('strided slice with several new axes', async () => {
// Python slice code: t[1:2,tf.newaxis,0:3,tf.newaxis,2:5]
const t = tf.zeros([2, 3, 4, 5]);
const begin = [1, 0, 0, 0, 2];
Expand All @@ -147,6 +147,7 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => {
const output = tf.stridedSlice(
t, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask);
expect(output.shape).toEqual([1, 1, 3, 1, 2, 5]);
expectArraysClose(await output.data(), new Array(30).fill(0));
});

it('strided slice with new axes and shrink axes', () => {
Expand Down

0 comments on commit bf1a062

Please sign in to comment.