From 7127840d823684e13f31b42c50e741489b9fd4ef Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 28 Apr 2020 09:45:44 -0700 Subject: [PATCH] fix tensor array gather not be able to handle indices longer than available tensors --- tfjs-converter/src/executor/tensor_array.ts | 6 ++++-- tfjs-converter/src/executor/tensor_array_test.ts | 5 +++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tfjs-converter/src/executor/tensor_array.ts b/tfjs-converter/src/executor/tensor_array.ts index 83158aaa171..2af5f546d12 100644 --- a/tfjs-converter/src/executor/tensor_array.ts +++ b/tfjs-converter/src/executor/tensor_array.ts @@ -67,9 +67,9 @@ export class TensorArray { throw new Error(`TensorArray ${this.name} has already been closed.`); } - if (index < 0 || index >= this.tensors.length) { + if (index < 0 || index >= this.size()) { throw new Error(`Tried to read from index ${index}, but array size is: ${ - this.tensors.length}`); + this.size()}`); } const tensorWithState = this.tensors[index]; @@ -182,6 +182,8 @@ export class TensorArray { for (let i = 0; i < this.size(); i++) { indices.push(i); } + } else { + indices = indices.slice(0, this.size()); } if (indices.length === 0) { diff --git a/tfjs-converter/src/executor/tensor_array_test.ts b/tfjs-converter/src/executor/tensor_array_test.ts index b569e19860d..15b95c3bdc2 100644 --- a/tfjs-converter/src/executor/tensor_array_test.ts +++ b/tfjs-converter/src/executor/tensor_array_test.ts @@ -145,6 +145,11 @@ describe('TensorArray', () => { expect(gathered.shape).toEqual([2, 1, 1]); test_util.expectArraysClose(await gathered.data(), [2, 1]); }); + it('should return when indices longer than available tensors', async () => { + const gathered = tensorArray.gather([1, 0, 2, 3]); + expect(gathered.shape).toEqual([2, 1, 1]); + test_util.expectArraysClose(await gathered.data(), [2, 1]); + }); it('should fail if dtype is not matched', () => { expect(() => tensorArray.gather([0, 1], 'float32')).toThrow(); });