Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 5 additions & 29 deletions tfjs-converter/src/executor/tensor_array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
* =============================================================================
*/

import {concat, DataType, slice, stack, Tensor, tensor, tidy, unstack, util} from '@tensorflow/tfjs-core';
import {concat, DataType, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';
import {assertShapesMatchAllowUndefinedSize} from './tensor_utils';

export interface TensorWithState {
tensor?: Tensor;
Expand Down Expand Up @@ -125,7 +126,7 @@ export class TensorArray {
this.elementShape = tensor.shape;
}

this.assertShapesMatchAllowUndefinedSize(
assertShapesMatchAllowUndefinedSize(
this.elementShape, tensor.shape,
`TensorArray ${this.name}: Could not write to TensorArray index ${
index}.`);
Expand Down Expand Up @@ -194,7 +195,7 @@ export class TensorArray {
// their memory.
const tensors = this.readMany(indices);

this.assertShapesMatchAllowUndefinedSize(
assertShapesMatchAllowUndefinedSize(
this.elementShape, tensors[0].shape, 'TensorArray shape mismatch: ');

return stack(tensors, 0);
Expand All @@ -220,7 +221,7 @@ export class TensorArray {
// Collect all the tensors from the tensors array.
const tensors = this.readMany(indices);

this.assertShapesMatchAllowUndefinedSize(
assertShapesMatchAllowUndefinedSize(
this.elementShape, tensors[0].shape,
`TensorArray shape mismatch: tensor array shape (${
this.elementShape}) vs first tensor shape (${tensors[0].shape})`);
Expand Down Expand Up @@ -303,29 +304,4 @@ export class TensorArray {
}
this.writeMany(indices, tensors);
}

/**
* This differs from util.assertShapesMatch in that it allows values of
* negative one, an undefined size of a dimensinon, in a shape to match
* anything.
*/
private assertShapesMatchAllowUndefinedSize(
shapeA: number[], shapeB: number[], errorMessagePrefix = ''): void {
util.assert(
this.shapesEqualAllowUndefinedSize(shapeA, shapeB),
() =>
errorMessagePrefix + ` Shapes ${shapeA} and ${shapeB} must match`);
}

private shapesEqualAllowUndefinedSize(n1: number[], n2: number[]) {
if (n1.length !== n2.length) {
return false;
}
for (let i = 0; i < n1.length; i++) {
if (n1[i] !== -1 && n2[i] !== -1 && n1[i] !== n2[i]) {
return false;
}
}
return true;
}
}
Loading