diff --git a/tfjs-converter/metadata/kernel2op.json b/tfjs-converter/metadata/kernel2op.json index 2f4016623f5..5bbc3857327 100644 --- a/tfjs-converter/metadata/kernel2op.json +++ b/tfjs-converter/metadata/kernel2op.json @@ -181,13 +181,15 @@ "fused.depthwiseConv2d" ], "Gather": [ - "gather" + "gather", + "cast" ], "GatherNd": [ "gatherND" ], "GatherV2": [ - "gather" + "gather", + "cast" ], "Greater": [ "greater" @@ -314,7 +316,9 @@ ], "Pack": [ "tidy", + "squeeze", "util.arraysEqual", + "reshape", "stack" ], "Pad": [ @@ -430,7 +434,8 @@ "spaceToBatchND" ], "SparseToDense": [ - "sparseToDense" + "sparseToDense", + "cast" ], "Split": [ "split" @@ -506,6 +511,7 @@ "unstack" ], "Where": [ + "cast", "whereAsync" ], "While": [], diff --git a/tfjs-converter/src/executor/tensor_array.ts b/tfjs-converter/src/executor/tensor_array.ts index 8b09681c27c..4ac46db6955 100644 --- a/tfjs-converter/src/executor/tensor_array.ts +++ b/tfjs-converter/src/executor/tensor_array.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {concat, DataType, keep, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; +import {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; import {assertShapesMatchAllowUndefinedSize} from './tensor_utils'; @@ -294,12 +294,12 @@ export class TensorArray { const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength; const tensors: Tensor[] = []; tidy(() => { - tensor = tensor.reshape([1, totalLength, elementPerRow]); + tensor = reshape(tensor, [1, totalLength, elementPerRow]); for (let i = 0; i < length.length; ++i) { const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1]; const indices = [0, previousLength, 0]; const sizes = [1, length[i], elementPerRow]; - tensors[i] = slice(tensor, indices, sizes).reshape(this.elementShape); + tensors[i] = reshape(slice(tensor, indices, sizes), this.elementShape); } return tensors; }); diff --git a/tfjs-converter/src/executor/tensor_list.ts b/tfjs-converter/src/executor/tensor_list.ts index ea3e26e37d7..9eb3c45b27b 100644 --- a/tfjs-converter/src/executor/tensor_list.ts +++ b/tfjs-converter/src/executor/tensor_list.ts @@ -15,7 +15,7 @@ * ============================================================================= */ -import {concat, DataType, keep, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; +import {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core'; import {assertShapesMatchAllowUndefinedSize} from './tensor_utils'; @@ -114,7 +114,7 @@ export class TensorList { elementShape, this.elementShape, 'TensorList shape mismatch: '); return tidy(() => { const reshapedTensors = - this.tensors.map(tensor => tensor.reshape(elementShape)); + this.tensors.map(tensor => reshape(tensor, elementShape)); return stack(reshapedTensors, 0); }); } @@ -137,7 +137,7 @@ export class TensorList { const tensor = this.tensors.pop(); assertShapesMatchAllowUndefinedSize( tensor.shape, elementShape, 'TensorList shape mismatch: '); - return tensor.reshape(elementShape); + return reshape(tensor, elementShape); } /** @@ -254,7 +254,7 @@ export class TensorList { } return tidy(() => { - const tensors = indices.map(i => this.tensors[i].reshape(elementShape)); + const tensors = indices.map(i => reshape(this.tensors[i], elementShape)); return stack(tensors, 0); }); } @@ -278,7 +278,7 @@ export class TensorList { } return tidy(() => { - const tensors = this.tensors.map(t => t.reshape(elementShape)); + const tensors = this.tensors.map(t => reshape(t, elementShape)); return concat(tensors, 0); }); } @@ -304,7 +304,7 @@ export function fromTensor( assertShapesMatchAllowUndefinedSize( outputShape, elementShape, 'TensorList shape mismatch: '); - const tensorList: Tensor[] = tensor.unstack(); + const tensorList: Tensor[] = unstack(tensor); return new TensorList(tensorList, elementShape, dtype); } @@ -373,12 +373,12 @@ export function split( const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength; const tensors: Tensor[] = tidy(() => { const tensors = []; - tensor = tensor.reshape([1, totalLength, elementPerRow]); + tensor = reshape(tensor, [1, totalLength, elementPerRow]); for (let i = 0; i < length.length; ++i) { const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1]; const indices = [0, previousLength, 0]; const sizes = [1, length[i], elementPerRow]; - tensors[i] = slice(tensor, indices, sizes).reshape(elementShape); + tensors[i] = reshape(slice(tensor, indices, sizes), elementShape); } tensor.dispose(); return tensors; diff --git a/tfjs-converter/src/operations/executors/dynamic_executor.ts b/tfjs-converter/src/operations/executors/dynamic_executor.ts index 326140573d0..2de8c19a86f 100644 --- a/tfjs-converter/src/operations/executors/dynamic_executor.ts +++ b/tfjs-converter/src/operations/executors/dynamic_executor.ts @@ -91,9 +91,9 @@ export const executeOp: InternalOpAsyncExecutor = async( iouThreshold, scoreThreshold)]; } case 'Where': { - const condition = - (getParamValue('condition', node, tensorMap, context) as tfc.Tensor) - .asType('bool'); + const condition = tfc.cast( + (getParamValue('condition', node, tensorMap, context) as tfc.Tensor), + 'bool'); const result = [await tfc.whereAsync(condition)]; condition.dispose(); return result; diff --git a/tfjs-converter/src/operations/executors/slice_join_executor.ts b/tfjs-converter/src/operations/executors/slice_join_executor.ts index b3b47e9c7a2..246887d4f1a 100644 --- a/tfjs-converter/src/operations/executors/slice_join_executor.ts +++ b/tfjs-converter/src/operations/executors/slice_join_executor.ts @@ -43,7 +43,7 @@ export const executeOp: InternalOpExecutor = (node: Node, const input = getParamValue('x', node, tensorMap, context) as tfc.Tensor; const indices = getParamValue('indices', node, tensorMap, context) as tfc.Tensor1D; - return [tfc.gather(input, indices.asType('int32'), axis)]; + return [tfc.gather(input, tfc.cast(indices, 'int32'), axis)]; } case 'ReverseV2': case 'Reverse': { @@ -89,14 +89,14 @@ export const executeOp: InternalOpExecutor = (node: Node, getParamValue('tensors', node, tensorMap, context) as tfc.Tensor[]; // Reshape the tensors to the first tensor's shape if they don't match. const shape = tensors[0].shape; - const squeezedShape = tensors[0].squeeze().shape; + const squeezedShape = tfc.squeeze(tensors[0]).shape; const mapped = tensors.map(tensor => { const sameShape = tfc.util.arraysEqual(tensor.shape, shape); if (!sameShape && - !tfc.util.arraysEqual(tensor.squeeze().shape, squeezedShape)) { + !tfc.util.arraysEqual(tfc.squeeze(tensor).shape, squeezedShape)) { throw new Error('the input tensors shape does not match'); } - return sameShape ? tensor : tensor.reshape(shape); + return sameShape ? tensor : tfc.reshape(tensor, shape); }); return [tfc.stack(mapped, axis)]; }); @@ -151,7 +151,7 @@ export const executeOp: InternalOpExecutor = (node: Node, indices, sparseValues, shape, sparseValues.dtype === defaultValue.dtype ? defaultValue : - defaultValue.asType(sparseValues.dtype))]; + tfc.cast(defaultValue, sparseValues.dtype))]; } default: throw TypeError(`Node type ${node.op} is not implemented`);