diff --git a/tfjs-converter/src/executor/tensor_array.ts b/tfjs-converter/src/executor/tensor_array.ts index 83158aaa171..2af5f546d12 100644 --- a/tfjs-converter/src/executor/tensor_array.ts +++ b/tfjs-converter/src/executor/tensor_array.ts @@ -67,9 +67,9 @@ export class TensorArray { throw new Error(`TensorArray ${this.name} has already been closed.`); } - if (index < 0 || index >= this.tensors.length) { + if (index < 0 || index >= this.size()) { throw new Error(`Tried to read from index ${index}, but array size is: ${ - this.tensors.length}`); + this.size()}`); } const tensorWithState = this.tensors[index]; @@ -182,6 +182,8 @@ export class TensorArray { for (let i = 0; i < this.size(); i++) { indices.push(i); } + } else { + indices = indices.slice(0, this.size()); } if (indices.length === 0) { diff --git a/tfjs-converter/src/executor/tensor_array_test.ts b/tfjs-converter/src/executor/tensor_array_test.ts index b569e19860d..15b95c3bdc2 100644 --- a/tfjs-converter/src/executor/tensor_array_test.ts +++ b/tfjs-converter/src/executor/tensor_array_test.ts @@ -145,6 +145,11 @@ describe('TensorArray', () => { expect(gathered.shape).toEqual([2, 1, 1]); test_util.expectArraysClose(await gathered.data(), [2, 1]); }); + it('should return when indices longer than available tensors', async () => { + const gathered = tensorArray.gather([1, 0, 2, 3]); + expect(gathered.shape).toEqual([2, 1, 1]); + test_util.expectArraysClose(await gathered.data(), [2, 1]); + }); it('should fail if dtype is not matched', () => { expect(() => tensorArray.gather([0, 1], 'float32')).toThrow(); });