diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index e7a0965bd4c..99f8bf22f98 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -77,13 +77,7 @@ export const executeOp: InternalOpExecutor = (node: Node, const shrinkAxisMask = getParamValue('shrinkAxisMask', node, tensorMap, context) as number; const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor; - if (begin.length === 1 && tensor.shape.length > 1) { - for (let i = 1; i < tensor.shape.length; i++) { - begin.push(0); - end.push(tensor.shape[i]); - strides.push(strides[0]); - } - } + return [tfc.stridedSlice( tensor, begin, end, strides, beginMask, endMask, ellipsisMask, newAxisMask, shrinkAxisMask)]; diff --git a/tfjs-core/src/ops/cumsum.ts b/tfjs-core/src/ops/cumsum.ts index 857a4f62cfc..96a62e32ded 100644 --- a/tfjs-core/src/ops/cumsum.ts +++ b/tfjs-core/src/ops/cumsum.ts @@ -24,7 +24,7 @@ import {GradSaveFunc, NamedTensorMap} from '../tensor_types'; import {convertToTensor} from '../tensor_util_env'; import {TensorLike} from '../types'; -import {getAxesPermutation, getInnerMostAxes} from './axis_util'; +import {getAxesPermutation, getInnerMostAxes, getUndoAxesPermutation} from './axis_util'; import {op} from './operation'; import {transpose} from './transpose'; @@ -66,7 +66,8 @@ function cumsum_( save([$x]); if (permutation != null) { - value = transpose(value, permutation); + const reversePermutation = getUndoAxesPermutation(permutation); + value = transpose(value, reversePermutation); } return value; }; diff --git a/tfjs-core/src/ops/cumsum_test.ts b/tfjs-core/src/ops/cumsum_test.ts index b6c4205ac0d..b56e86f62a9 100644 --- a/tfjs-core/src/ops/cumsum_test.ts +++ b/tfjs-core/src/ops/cumsum_test.ts @@ -93,6 +93,11 @@ describeWithFlags('cumsum', ALL_ENVS, () => { expectArraysClose(await res.data(), [0, 1, 2, 5, 4, 9, 6, 13]); }); + it('handle permutation properly', async () => { + const res = tf.ones([1, 240, 1, 10]).cumsum(1); + expect(res.shape).toEqual([1, 240, 1, 10]); + }); + it('throws when passed a non-tensor', () => { expect(() => tf.cumsum({} as tf.Tensor)) .toThrowError(/Argument 'x' passed to 'cumsum' must be a Tensor/); diff --git a/tfjs-core/src/ops/slice_util.ts b/tfjs-core/src/ops/slice_util.ts index 9889cbb4c26..10799d2a57b 100644 --- a/tfjs-core/src/ops/slice_util.ts +++ b/tfjs-core/src/ops/slice_util.ts @@ -150,6 +150,11 @@ export function stopIndicesWithElidedDims( } for (let i = 0; i < newIndices.length; i++) { + // Handle negative indices + const axisSize = inputShape[i]; + if (newIndices[i] < 0) { + newIndices[i] += axisSize; + } newIndices[i] = util.clamp(0, newIndices[i], inputShape[i]); } return newIndices; diff --git a/tfjs-core/src/ops/strided_slice_test.ts b/tfjs-core/src/ops/strided_slice_test.ts index 49a1dc2754c..4e4d2d3be7d 100644 --- a/tfjs-core/src/ops/strided_slice_test.ts +++ b/tfjs-core/src/ops/strided_slice_test.ts @@ -446,6 +446,12 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => { .toThrowError(/Argument 'x' passed to 'stridedSlice' must be a Tensor/); }); + it('stridedSlice should handle negative end with ellipsisMask', () => { + const a = tf.ones([1, 240, 1, 10]); + const output = + tf.stridedSlice(a, [0, 0, 0], [0, -1, 0], [1, 1, 1], 3, 1, 4); + expect(output.shape).toEqual([1, 239, 1, 10]); + }); it('accepts a tensor-like object', async () => { const tensor = [0, 1, 2, 3]; const output = tf.stridedSlice(tensor, [0], [3], [2]);