Skip to content
This repository has been archived by the owner on Aug 15, 2019. It is now read-only.

adds error checking for tf.Buffer.get(outOfRangeLocation) #1630

Merged
merged 2 commits into from Mar 18, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/buffer_test.ts
Expand Up @@ -34,6 +34,16 @@ describeWithFlags('tf.buffer', ALL_ENVS, () => {
expectArraysClose(buff.values, new Float32Array([1.3, 0, 0, 2.9, 0, 0]));
});

it('get() out of range throws', () => {
const t = tf.tensor([1, 2, 3, 4, 5, 6, 7, 8], [2, 2, 2]);

const buff = t.bufferSync();
expect(buff.get(0, 0, 0)).toBeCloseTo(1);
expect(buff.get(0, 0, 1)).toBeCloseTo(2);
expect(() => buff.get(0, 0, 2))
.toThrowError(/Requested out of range element/);
});

it('int32', () => {
const buff = tf.buffer([2, 3], 'int32');
buff.set(1.3, 0, 0);
Expand Down
7 changes: 7 additions & 0 deletions src/tensor.ts
Expand Up @@ -98,6 +98,13 @@ export class TensorBuffer<R extends Rank, D extends DataType = 'float32'> {
if (locs.length === 0) {
locs = [0];
}
for (const i in locs) {
if (locs[i] < 0 || locs[i] >= this.shape[i]) {
const msg = `Requested out of range element at ${locs}. ` +
` Buffer shape=${this.shape}`;
throw new Error(msg);
}
}
let index = locs[locs.length - 1];
for (let i = 0; i < locs.length - 1; ++i) {
index += this.strides[i] * locs[i];
Expand Down