Skip to content

Commit 2bc0ceb

Browse files
authored
fix cumsum and strided slice bugs (#3638)
BUG
1 parent f8956ae commit 2bc0ceb

File tree

5 files changed

+20
-9
lines changed

5 files changed

+20
-9
lines changed

tfjs-converter/src/operations/executors/slice_join_executor.ts

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,13 +77,7 @@ export const executeOp: InternalOpExecutor = (node: Node,
7777
const shrinkAxisMask =
7878
getParamValue('shrinkAxisMask', node, tensorMap, context) as number;
7979
const tensor = getParamValue('x', node, tensorMap, context) as tfc.Tensor;
80-
if (begin.length === 1 && tensor.shape.length > 1) {
81-
for (let i = 1; i < tensor.shape.length; i++) {
82-
begin.push(0);
83-
end.push(tensor.shape[i]);
84-
strides.push(strides[0]);
85-
}
86-
}
80+
8781
return [tfc.stridedSlice(
8882
tensor, begin, end, strides, beginMask, endMask, ellipsisMask,
8983
newAxisMask, shrinkAxisMask)];

tfjs-core/src/ops/cumsum.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ import {GradSaveFunc, NamedTensorMap} from '../tensor_types';
2424
import {convertToTensor} from '../tensor_util_env';
2525
import {TensorLike} from '../types';
2626

27-
import {getAxesPermutation, getInnerMostAxes} from './axis_util';
27+
import {getAxesPermutation, getInnerMostAxes, getUndoAxesPermutation} from './axis_util';
2828
import {op} from './operation';
2929
import {transpose} from './transpose';
3030

@@ -66,7 +66,8 @@ function cumsum_<T extends Tensor>(
6666
save([$x]);
6767

6868
if (permutation != null) {
69-
value = transpose(value, permutation);
69+
const reversePermutation = getUndoAxesPermutation(permutation);
70+
value = transpose(value, reversePermutation);
7071
}
7172
return value;
7273
};

tfjs-core/src/ops/cumsum_test.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ describeWithFlags('cumsum', ALL_ENVS, () => {
9393
expectArraysClose(await res.data(), [0, 1, 2, 5, 4, 9, 6, 13]);
9494
});
9595

96+
it('handle permutation properly', async () => {
97+
const res = tf.ones([1, 240, 1, 10]).cumsum(1);
98+
expect(res.shape).toEqual([1, 240, 1, 10]);
99+
});
100+
96101
it('throws when passed a non-tensor', () => {
97102
expect(() => tf.cumsum({} as tf.Tensor))
98103
.toThrowError(/Argument 'x' passed to 'cumsum' must be a Tensor/);

tfjs-core/src/ops/slice_util.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ export function stopIndicesWithElidedDims(
150150
}
151151

152152
for (let i = 0; i < newIndices.length; i++) {
153+
// Handle negative indices
154+
const axisSize = inputShape[i];
155+
if (newIndices[i] < 0) {
156+
newIndices[i] += axisSize;
157+
}
153158
newIndices[i] = util.clamp(0, newIndices[i], inputShape[i]);
154159
}
155160
return newIndices;

tfjs-core/src/ops/strided_slice_test.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,12 @@ describeWithFlags('stridedSlice', ALL_ENVS, () => {
446446
.toThrowError(/Argument 'x' passed to 'stridedSlice' must be a Tensor/);
447447
});
448448

449+
it('stridedSlice should handle negative end with ellipsisMask', () => {
450+
const a = tf.ones([1, 240, 1, 10]);
451+
const output =
452+
tf.stridedSlice(a, [0, 0, 0], [0, -1, 0], [1, 1, 1], 3, 1, 4);
453+
expect(output.shape).toEqual([1, 239, 1, 10]);
454+
});
449455
it('accepts a tensor-like object', async () => {
450456
const tensor = [0, 1, 2, 3];
451457
const output = tf.stridedSlice(tensor, [0], [3], [2]);

0 commit comments

Comments
 (0)