Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions tfjs-converter/metadata/kernel2op.json
Original file line number Diff line number Diff line change
Expand Up @@ -181,13 +181,15 @@
"fused.depthwiseConv2d"
],
"Gather": [
"gather"
"gather",
"cast"
],
"GatherNd": [
"gatherND"
],
"GatherV2": [
"gather"
"gather",
"cast"
],
"Greater": [
"greater"
Expand Down Expand Up @@ -314,7 +316,9 @@
],
"Pack": [
"tidy",
"squeeze",
"util.arraysEqual",
"reshape",
"stack"
],
"Pad": [
Expand Down Expand Up @@ -430,7 +434,8 @@
"spaceToBatchND"
],
"SparseToDense": [
"sparseToDense"
"sparseToDense",
"cast"
],
"Split": [
"split"
Expand Down Expand Up @@ -506,6 +511,7 @@
"unstack"
],
"Where": [
"cast",
"whereAsync"
],
"While": [],
Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/src/executor/tensor_array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {concat, DataType, keep, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';
import {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';

import {assertShapesMatchAllowUndefinedSize} from './tensor_utils';

Expand Down Expand Up @@ -294,12 +294,12 @@ export class TensorArray {
const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
const tensors: Tensor[] = [];
tidy(() => {
tensor = tensor.reshape([1, totalLength, elementPerRow]);
tensor = reshape(tensor, [1, totalLength, elementPerRow]);
for (let i = 0; i < length.length; ++i) {
const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
const indices = [0, previousLength, 0];
const sizes = [1, length[i], elementPerRow];
tensors[i] = slice(tensor, indices, sizes).reshape(this.elementShape);
tensors[i] = reshape(slice(tensor, indices, sizes), this.elementShape);
}
return tensors;
});
Expand Down
16 changes: 8 additions & 8 deletions tfjs-converter/src/executor/tensor_list.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {concat, DataType, keep, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';
import {concat, DataType, keep, reshape, scalar, slice, stack, Tensor, tensor, tidy, unstack} from '@tensorflow/tfjs-core';

import {assertShapesMatchAllowUndefinedSize} from './tensor_utils';

Expand Down Expand Up @@ -114,7 +114,7 @@ export class TensorList {
elementShape, this.elementShape, 'TensorList shape mismatch: ');
return tidy(() => {
const reshapedTensors =
this.tensors.map(tensor => tensor.reshape(elementShape));
this.tensors.map(tensor => reshape(tensor, elementShape));
return stack(reshapedTensors, 0);
});
}
Expand All @@ -137,7 +137,7 @@ export class TensorList {
const tensor = this.tensors.pop();
assertShapesMatchAllowUndefinedSize(
tensor.shape, elementShape, 'TensorList shape mismatch: ');
return tensor.reshape(elementShape);
return reshape(tensor, elementShape);
}

/**
Expand Down Expand Up @@ -254,7 +254,7 @@ export class TensorList {
}

return tidy(() => {
const tensors = indices.map(i => this.tensors[i].reshape(elementShape));
const tensors = indices.map(i => reshape(this.tensors[i], elementShape));
return stack(tensors, 0);
});
}
Expand All @@ -278,7 +278,7 @@ export class TensorList {
}

return tidy(() => {
const tensors = this.tensors.map(t => t.reshape(elementShape));
const tensors = this.tensors.map(t => reshape(t, elementShape));
return concat(tensors, 0);
});
}
Expand All @@ -304,7 +304,7 @@ export function fromTensor(
assertShapesMatchAllowUndefinedSize(
outputShape, elementShape, 'TensorList shape mismatch: ');

const tensorList: Tensor[] = tensor.unstack();
const tensorList: Tensor[] = unstack(tensor);
return new TensorList(tensorList, elementShape, dtype);
}

Expand Down Expand Up @@ -373,12 +373,12 @@ export function split(
const elementPerRow = totalLength === 0 ? 0 : tensor.size / totalLength;
const tensors: Tensor[] = tidy(() => {
const tensors = [];
tensor = tensor.reshape([1, totalLength, elementPerRow]);
tensor = reshape(tensor, [1, totalLength, elementPerRow]);
for (let i = 0; i < length.length; ++i) {
const previousLength = (i === 0) ? 0 : cumulativeLengths[i - 1];
const indices = [0, previousLength, 0];
const sizes = [1, length[i], elementPerRow];
tensors[i] = slice(tensor, indices, sizes).reshape(elementShape);
tensors[i] = reshape(slice(tensor, indices, sizes), elementShape);
}
tensor.dispose();
return tensors;
Expand Down
6 changes: 3 additions & 3 deletions tfjs-converter/src/operations/executors/dynamic_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ export const executeOp: InternalOpAsyncExecutor = async(
iouThreshold, scoreThreshold)];
}
case 'Where': {
const condition =
(getParamValue('condition', node, tensorMap, context) as tfc.Tensor)
.asType('bool');
const condition = tfc.cast(
(getParamValue('condition', node, tensorMap, context) as tfc.Tensor),
'bool');
const result = [await tfc.whereAsync(condition)];
condition.dispose();
return result;
Expand Down
10 changes: 5 additions & 5 deletions tfjs-converter/src/operations/executors/slice_join_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ export const executeOp: InternalOpExecutor = (node: Node,
const input = getParamValue('x', node, tensorMap, context) as tfc.Tensor;
const indices =
getParamValue('indices', node, tensorMap, context) as tfc.Tensor1D;
return [tfc.gather(input, indices.asType('int32'), axis)];
return [tfc.gather(input, tfc.cast(indices, 'int32'), axis)];
}
case 'ReverseV2':
case 'Reverse': {
Expand Down Expand Up @@ -89,14 +89,14 @@ export const executeOp: InternalOpExecutor = (node: Node,
getParamValue('tensors', node, tensorMap, context) as tfc.Tensor[];
// Reshape the tensors to the first tensor's shape if they don't match.
const shape = tensors[0].shape;
const squeezedShape = tensors[0].squeeze().shape;
const squeezedShape = tfc.squeeze(tensors[0]).shape;
const mapped = tensors.map(tensor => {
const sameShape = tfc.util.arraysEqual(tensor.shape, shape);
if (!sameShape &&
!tfc.util.arraysEqual(tensor.squeeze().shape, squeezedShape)) {
!tfc.util.arraysEqual(tfc.squeeze(tensor).shape, squeezedShape)) {
throw new Error('the input tensors shape does not match');
}
return sameShape ? tensor : tensor.reshape(shape);
return sameShape ? tensor : tfc.reshape(tensor, shape);
});
return [tfc.stack(mapped, axis)];
});
Expand Down Expand Up @@ -151,7 +151,7 @@ export const executeOp: InternalOpExecutor = (node: Node,
indices, sparseValues, shape,
sparseValues.dtype === defaultValue.dtype ?
defaultValue :
defaultValue.asType(sparseValues.dtype))];
tfc.cast(defaultValue, sparseValues.dtype))];
}
default:
throw TypeError(`Node type ${node.op} is not implemented`);
Expand Down