diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 90f248c4a3c..2e4871f89c1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -87,6 +87,7 @@ To upgrade the version of TensorFlow that is embedded within TensorFlow Java, pl 7. Copy the content of `tensorflow-x.x.x/.bazelversion` file to `tensorflow-core/tensorflow-core-native/.bazelversion` 8. Validate that options in `tensorflow-core/tensorflow-core-native/.bazelrc` are still accurate or update them accordingly 9. Update URLs of existing TensorFlow binaries in the `tensorflow-core/tensorflow-core-native/scripts/dist_download` script +10. Update URLs of TensorFlow-Text binaries used for testing in the `tensorflow-core/tensorflow-core-api/scripts/test_download` script #### Patching TensorFlow Sources diff --git a/README.md b/README.md index efb6e31be97..ff5d158d748 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,7 @@ This table shows the mapping between TensorFlow, TensorFlow Java and minimum sup | 0.4.2 | 2.7.4 | 8 | | 0.5.0 | 2.10.1 | 11 | | 0.6.0-SNAPSHOT | 2.10.1 | 11 | -| 1.0.0-SNAPSHOT | 2.15.0 | 11 | +| 1.0.0-SNAPSHOT | 2.16.1 | 11 | ## How to Contribute? 
diff --git a/pom.xml b/pom.xml index f86e92bc69a..26aef342cbd 100644 --- a/pom.xml +++ b/pom.xml @@ -43,7 +43,7 @@ 5.10.0 1.37 2.7 - 2.10.0 + 2.25.0 true true true diff --git a/tensorflow-core/tensorflow-core-api/scripts/test_download.sh b/tensorflow-core/tensorflow-core-api/scripts/test_download.sh index 146868f26d6..03e4c854285 100755 --- a/tensorflow-core/tensorflow-core-api/scripts/test_download.sh +++ b/tensorflow-core/tensorflow-core-api/scripts/test_download.sh @@ -5,10 +5,10 @@ DOWNLOAD_FOLDER="$1" case ${PLATFORM:-} in 'linux-x86_64') - TEXT_WHEEL_URL='https://files.pythonhosted.org/packages/20/a0/bdbf2a11141f1c93e572364d13c42537cfe811b747a0bbb58fdd904f3960/tensorflow_text-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl' + TEXT_WHEEL_URL='https://files.pythonhosted.org/packages/c5/ef/5b8270e5665923bda4222f56382d9fbcb7fd6efd5fb8557ad0776848cdff/tensorflow_text-2.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl' ;; 'macosx-x86_64') - TEXT_WHEEL_URL='https://files.pythonhosted.org/packages/8a/fe/a2f19d3d3ab834c3fa1007c970b0b86573beb929c86ca6c85cd13e86e4b2/tensorflow_text-2.15.0-cp311-cp311-macosx_10_9_x86_64.whl' + TEXT_WHEEL_URL='https://files.pythonhosted.org/packages/ed/5d/b55f48cdf98a164d293f660748c2501ea828e148250a4cadbb5b0d988735/tensorflow_text-2.16.1-cp311-cp311-macosx_10_9_x86_64.whl' ;; *) echo "TensorFlow Text distribution for ${PLATFORM} is not supported for download" diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_ComputeDedupDataSize.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_ComputeDedupDataSize.pbtxt new file mode 100644 index 00000000000..91551fab016 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_ComputeDedupDataSize.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "ComputeDedupDataSize" + visibility: VISIBLE + endpoint { + name: "tpu.ComputeDedupDataSize" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_ConvertToCooTensor.pbtxt 
b/tensorflow-core/tensorflow-core-api/src/api/api_def_ConvertToCooTensor.pbtxt new file mode 100644 index 00000000000..3047ada98b7 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_ConvertToCooTensor.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "ConvertToCooTensor" + visibility: VISIBLE + endpoint { + name: "tpu.ConvertToCooTensor" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_DatasetFingerprint.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_DatasetFingerprint.pbtxt new file mode 100644 index 00000000000..61e0086729b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_DatasetFingerprint.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "DatasetFingerprint" + visibility: VISIBLE + endpoint { + name: "data.DatasetFingerprint" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt new file mode 100644 index 00000000000..a9a8710fdb5 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_GetMinibatchSplitsWithPhysicalReplica.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "GetMinibatchSplitsWithPhysicalReplica" + visibility: VISIBLE + endpoint { + name: "tpu.GetMinibatchSplitsWithPhysicalReplica" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt new file mode 100644 index 00000000000..9ee5d7e2e5b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_GetMinibatchesInCsrWithPhysicalReplica.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "GetMinibatchesInCsrWithPhysicalReplica" + visibility: VISIBLE + endpoint { + name: "tpu.GetMinibatchesInCsrWithPhysicalReplica" + } +} diff --git 
a/tensorflow-core/tensorflow-core-api/src/api/api_def_GlobalIterId.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_GlobalIterId.pbtxt new file mode 100644 index 00000000000..8a795a8ef23 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_GlobalIterId.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "GlobalIterId" + visibility: VISIBLE + endpoint { + name: "tpu.GlobalIterId" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_ListSnapshotChunksDataset.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_ListSnapshotChunksDataset.pbtxt new file mode 100644 index 00000000000..e48c7bf0a11 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_ListSnapshotChunksDataset.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "ListSnapshotChunksDataset" + visibility: VISIBLE + endpoint { + name: "data.ListSnapshotChunksDataset" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_StoreMinibatchStatisticsInFdo.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_StoreMinibatchStatisticsInFdo.pbtxt new file mode 100644 index 00000000000..f6d523d7da2 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_StoreMinibatchStatisticsInFdo.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "StoreMinibatchStatisticsInFdo" + visibility: VISIBLE + endpoint { + name: "tpu.StoreMinibatchStatisticsInFdo" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt new file mode 100644 index 00000000000..f7502de956c --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_TPUAnnotateTensorsWithDynamicShape.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "TPUAnnotateTensorsWithDynamicShape" + visibility: VISIBLE + endpoint { + name: "tpu.TPUAnnotateTensorsWithDynamicShape" + } +} diff --git 
a/tensorflow-core/tensorflow-core-api/src/api/api_def_TPUCopyWithDynamicShape.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_TPUCopyWithDynamicShape.pbtxt new file mode 100644 index 00000000000..d26aa20a655 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_TPUCopyWithDynamicShape.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "TPUCopyWithDynamicShape" + visibility: VISIBLE + endpoint { + name: "tpu.TPUCopyWithDynamicShape" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdagrad.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdagrad.pbtxt new file mode 100644 index 00000000000..fad2b3f25f0 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdagrad.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseCoreAdagrad" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseCoreAdagrad" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdagradMomentum.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdagradMomentum.pbtxt new file mode 100644 index 00000000000..c5322e9d0dc --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdagradMomentum.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseCoreAdagradMomentum" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseCoreAdagradMomentum" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdam.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdam.pbtxt new file mode 100644 index 00000000000..f9333b7ef14 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreAdam.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseCoreAdam" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseCoreAdam" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreFtrl.pbtxt 
b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreFtrl.pbtxt new file mode 100644 index 00000000000..8513db03498 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreFtrl.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseCoreFtrl" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseCoreFtrl" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreSgd.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreSgd.pbtxt new file mode 100644 index 00000000000..b7f11b804eb --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseCoreSgd.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseCoreSgd" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseCoreSgd" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmul.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmul.pbtxt new file mode 100644 index 00000000000..744371be42b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmul.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseDenseMatmul" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseDenseMatmul" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt new file mode 100644 index 00000000000..404c6711e3f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdagradAndCsrInput.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseDenseMatmulGradWithAdagradAndCsrInput" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseDenseMatmulGradWithAdagradAndCsrInput" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt 
b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt new file mode 100644 index 00000000000..c2c77111f0a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt new file mode 100644 index 00000000000..70bd843e8bf --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithAdamAndCsrInput.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseDenseMatmulGradWithAdamAndCsrInput" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseDenseMatmulGradWithAdamAndCsrInput" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt new file mode 100644 index 00000000000..24d83334982 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithFtrlAndCsrInput.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseDenseMatmulGradWithFtrlAndCsrInput" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseDenseMatmulGradWithFtrlAndCsrInput" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt new file mode 100644 index 00000000000..606ac6feb41 --- /dev/null +++ 
b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulGradWithSgdAndCsrInput.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseDenseMatmulGradWithSgdAndCsrInput" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseDenseMatmulGradWithSgdAndCsrInput" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt new file mode 100644 index 00000000000..00a2bed7460 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/api/api_def_XlaSparseDenseMatmulWithCsrInput.pbtxt @@ -0,0 +1,7 @@ +op { + graph_op_name: "XlaSparseDenseMatmulWithCsrInput" + visibility: VISIBLE + endpoint { + name: "xla.XlaSparseDenseMatmulWithCsrInput" + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java index 3f682c82355..80f23c71450 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/DataOps.java @@ -35,6 +35,7 @@ import org.tensorflow.op.data.ConcatenateDataset; import org.tensorflow.op.data.DataServiceDataset; import org.tensorflow.op.data.DatasetCardinality; +import org.tensorflow.op.data.DatasetFingerprint; import org.tensorflow.op.data.DatasetFromGraph; import org.tensorflow.op.data.DatasetToGraph; import org.tensorflow.op.data.DatasetToSingleElement; @@ -62,6 +63,7 @@ import org.tensorflow.op.data.LMDBDataset; import org.tensorflow.op.data.LatencyStatsDataset; import org.tensorflow.op.data.LegacyParallelInterleaveDataset; +import org.tensorflow.op.data.ListSnapshotChunksDataset; import org.tensorflow.op.data.LoadDataset; import org.tensorflow.op.data.MakeIterator; import org.tensorflow.op.data.MapAndBatchDataset; @@ -390,6 +392,17 @@ public 
DatasetCardinality datasetCardinality(Operand inputDatas return DatasetCardinality.create(scope, inputDataset, options); } + /** + * Returns the fingerprint of {@code input_dataset}. + * Returns the fingerprint of {@code input_dataset}. + * + * @param inputDataset A variant tensor representing the dataset to return fingerprint for. + * @return a new instance of DatasetFingerprint + */ + public DatasetFingerprint datasetFingerprint(Operand inputDataset) { + return DatasetFingerprint.create(scope, inputDataset); + } + /** * Creates a dataset from the given {@code graph_def}. * Creates a dataset from the provided {@code graph_def}. @@ -873,6 +886,19 @@ public LegacyParallelInterleaveDataset legacyParallelInterleaveDataset( return LegacyParallelInterleaveDataset.create(scope, inputDataset, otherArguments, cycleLength, blockLength, bufferOutputElements, prefetchInputElements, f, outputTypes, outputShapes, options); } + /** + * The ListSnapshotChunksDataset operation + * + * @param snapshotPath The snapshotPath value + * @param outputTypes The value of the outputTypes attribute + * @param outputShapes The value of the outputShapes attribute + * @return a new instance of ListSnapshotChunksDataset + */ + public ListSnapshotChunksDataset listSnapshotChunksDataset(Operand snapshotPath, + List> outputTypes, List outputShapes) { + return ListSnapshotChunksDataset.create(scope, snapshotPath, outputTypes, outputShapes); + } + /** * The LoadDataset operation * diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index 5dd7f842b47..91171014690 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -2338,7 +2338,7 @@ public DynamicPartition dynamicPartition(Operand data, * must have {@code data[i].shape = indices[i].shape + 
constant}. In terms of this * {@code constant}, the output shape is *
-   *  merged.shape = [max(indices)] + constant
+   *  merged.shape = [max(indices) + 1] + constant
    *  
*

Values are merged in order, so if an index appears in both {@code indices[m][i]} and * {@code indices[n][j]} for {@code (m,i) < (n,j)} the slice {@code data[n][j]} will appear in the @@ -7599,7 +7599,7 @@ public TensorStridedSliceUpdate tensorSt * * * @param data type for {@code output} output - * @param input 1-D or higher. + * @param input Can be of any rank. * @param multiples 1-D. Length must be the same as the number of dimensions in {@code input} * @param data type for {@code Tile} output and operands * @return a new instance of Tile diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java index 83e6f9d60b0..9e2411bc0a1 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/TpuOps.java @@ -24,19 +24,29 @@ import org.tensorflow.op.tpu.CollateTPUEmbeddingMemory; import org.tensorflow.op.tpu.Compile; import org.tensorflow.op.tpu.CompileSucceededAssert; +import org.tensorflow.op.tpu.ComputeDedupDataSize; import org.tensorflow.op.tpu.ConfigureAndInitializeGlobalTPU; import org.tensorflow.op.tpu.ConfigureTPUEmbeddingHost; import org.tensorflow.op.tpu.ConfigureTPUEmbeddingMemory; import org.tensorflow.op.tpu.ConnectTPUEmbeddingHosts; +import org.tensorflow.op.tpu.ConvertToCooTensor; import org.tensorflow.op.tpu.DTensorRestore; import org.tensorflow.op.tpu.Execute; import org.tensorflow.op.tpu.ExecuteAndUpdateVariables; import org.tensorflow.op.tpu.ExecuteTPUEmbeddingPartitioner; import org.tensorflow.op.tpu.FinalizeTPUEmbedding; +import org.tensorflow.op.tpu.GetMinibatchSplitsWithPhysicalReplica; +import org.tensorflow.op.tpu.GetMinibatchesInCsrWithPhysicalReplica; +import org.tensorflow.op.tpu.GlobalIterId; import org.tensorflow.op.tpu.PartitionedOutput; import org.tensorflow.op.tpu.ShutdownTPUSystem; +import 
org.tensorflow.op.tpu.StoreMinibatchStatisticsInFdo; +import org.tensorflow.op.tpu.TPUAnnotateTensorsWithDynamicShape; +import org.tensorflow.op.tpu.TPUCopyWithDynamicShape; import org.tensorflow.op.tpu.TPURoundRobin; import org.tensorflow.op.tpu.TpuHandleToProtoKey; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; import org.tensorflow.types.TInt64; import org.tensorflow.types.TString; import org.tensorflow.types.family.TType; @@ -111,6 +121,19 @@ public CompileSucceededAssert compileSucceededAssert(Operand compilatio return CompileSucceededAssert.create(scope, compilationStatus); } + /** + * An op computes the size of the deduplication data from embedding core and returns the updated config. + * This op is to compute size of the deduplication data so to provide this + * information to the op that computes the tuple mask of deduplication data can + * have static output shape. + * + * @param config Serialized TPUEmbeddingConfiguration proto. + * @return a new instance of ComputeDedupDataSize + */ + public ComputeDedupDataSize computeDedupDataSize(String config) { + return ComputeDedupDataSize.create(scope, config); + } + /** * An op that sets up the centralized structures for a distributed TPU system. 
* @@ -163,6 +186,21 @@ public ConnectTPUEmbeddingHosts connectTPUEmbeddingHosts( return ConnectTPUEmbeddingHosts.create(scope, networkConfigs); } + /** + * The ConvertToCooTensor operation + * + * @param indicesOrRowSplits The indicesOrRowSplits value + * @param values The values value + * @param weights The weights value + * @param sampleCount The value of the sampleCount attribute + * @param combiner The value of the combiner attribute + * @return a new instance of ConvertToCooTensor + */ + public ConvertToCooTensor convertToCooTensor(Operand indicesOrRowSplits, + Operand values, Operand weights, Long sampleCount, String combiner) { + return ConvertToCooTensor.create(scope, indicesOrRowSplits, values, weights, sampleCount, combiner); + } + /** * The DTensorRestoreV2 operation * @@ -244,6 +282,66 @@ public FinalizeTPUEmbedding finalizeTPUEmbedding(Operand commonConfig, return FinalizeTPUEmbedding.create(scope, commonConfig, memoryConfig); } + /** + * The GetMinibatchSplitsWithPhysicalReplica operation + * + * @param programKey The programKey value + * @param rowIds The rowIds value + * @param colIds The colIds value + * @param gains The gains value + * @param sampleCount The value of the sampleCount attribute + * @param numReplica The value of the numReplica attribute + * @param tableVocabSize The value of the tableVocabSize attribute + * @param featureWidth The value of the featureWidth attribute + * @param numScPerChip The value of the numScPerChip attribute + * @param tableName The value of the tableName attribute + * @param miniBatchSplits The value of the miniBatchSplits attribute + * @return a new instance of GetMinibatchSplitsWithPhysicalReplica + */ + public GetMinibatchSplitsWithPhysicalReplica getMinibatchSplitsWithPhysicalReplica( + Operand programKey, Operand rowIds, Operand colIds, + Operand gains, Long sampleCount, Long numReplica, Long tableVocabSize, + Long featureWidth, Long numScPerChip, String tableName, String miniBatchSplits) { + return 
GetMinibatchSplitsWithPhysicalReplica.create(scope, programKey, rowIds, colIds, gains, sampleCount, numReplica, tableVocabSize, featureWidth, numScPerChip, tableName, miniBatchSplits); + } + + /** + * The GetMinibatchesInCsrWithPhysicalReplica operation + * + * @param programKey The programKey value + * @param rowIds The rowIds value + * @param colIds The colIds value + * @param gains The gains value + * @param splits The splits value + * @param idCounts The idCounts value + * @param sampleCount The value of the sampleCount attribute + * @param numReplica The value of the numReplica attribute + * @param maxMinibatchesPerSc The value of the maxMinibatchesPerSc attribute + * @param maxIdsPerChipPerSample The value of the maxIdsPerChipPerSample attribute + * @param tableVocabSize The value of the tableVocabSize attribute + * @param featureWidth The value of the featureWidth attribute + * @param numScPerChip The value of the numScPerChip attribute + * @param tableName The value of the tableName attribute + * @param miniBatchInCsr The value of the miniBatchInCsr attribute + * @return a new instance of GetMinibatchesInCsrWithPhysicalReplica + */ + public GetMinibatchesInCsrWithPhysicalReplica getMinibatchesInCsrWithPhysicalReplica( + Operand programKey, Operand rowIds, Operand colIds, + Operand gains, Operand splits, Operand idCounts, Long sampleCount, + Long numReplica, Long maxMinibatchesPerSc, Long maxIdsPerChipPerSample, Long tableVocabSize, + Long featureWidth, Long numScPerChip, String tableName, String miniBatchInCsr) { + return GetMinibatchesInCsrWithPhysicalReplica.create(scope, programKey, rowIds, colIds, gains, splits, idCounts, sampleCount, numReplica, maxMinibatchesPerSc, maxIdsPerChipPerSample, tableVocabSize, featureWidth, numScPerChip, tableName, miniBatchInCsr); + } + + /** + * The GlobalIterId operation + * + * @return a new instance of GlobalIterId + */ + public GlobalIterId globalIterId() { + return GlobalIterId.create(scope); + } + /** * An op that 
demultiplexes a tensor to be sharded by XLA to a list of partitioned * outputs outside the XLA computation. Supports ND sharding. @@ -270,6 +368,50 @@ public ShutdownTPUSystem shutdownTPUSystem() { return ShutdownTPUSystem.create(scope); } + /** + * The StoreMinibatchStatisticsInFdo operation + * + * @param programKey The programKey value + * @param maxIds The maxIds value + * @param maxUniques The maxUniques value + * @param sampleCount The value of the sampleCount attribute + * @param numReplica The value of the numReplica attribute + * @param featureWidth The value of the featureWidth attribute + * @param numScPerChip The value of the numScPerChip attribute + * @param tableName The value of the tableName attribute + * @param miniBatchSplits The value of the miniBatchSplits attribute + * @return a new instance of StoreMinibatchStatisticsInFdo + */ + public StoreMinibatchStatisticsInFdo storeMinibatchStatisticsInFdo(Operand programKey, + Operand maxIds, Operand maxUniques, Long sampleCount, Long numReplica, + Long featureWidth, Long numScPerChip, String tableName, String miniBatchSplits) { + return StoreMinibatchStatisticsInFdo.create(scope, programKey, maxIds, maxUniques, sampleCount, numReplica, featureWidth, numScPerChip, tableName, miniBatchSplits); + } + + /** + * The TPUAnnotateTensorsWithDynamicShape operation + * + * @param tensors The tensors value + * @return a new instance of TPUAnnotateTensorsWithDynamicShape + */ + public TPUAnnotateTensorsWithDynamicShape tPUAnnotateTensorsWithDynamicShape( + Iterable> tensors) { + return TPUAnnotateTensorsWithDynamicShape.create(scope, tensors); + } + + /** + * Op that copies host tensor to device with dynamic shape support. + * For internal use only. 
+ * + * @param tensors The tensors value + * @param unpaddedSizes The unpaddedSizes value + * @return a new instance of TPUCopyWithDynamicShape + */ + public TPUCopyWithDynamicShape tPUCopyWithDynamicShape(Iterable> tensors, + Iterable> unpaddedSizes) { + return TPUCopyWithDynamicShape.create(scope, tensors, unpaddedSizes); + } + /** * Round-robin load balancing on TPU cores. * A load balancing op that round-robins among TPU cores. diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java index 1136fb98bf1..116388309ba 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/XlaOps.java @@ -24,6 +24,20 @@ import org.tensorflow.op.xla.XlaHostCompute; import org.tensorflow.op.xla.XlaRecvFromHost; import org.tensorflow.op.xla.XlaSendToHost; +import org.tensorflow.op.xla.XlaSparseCoreAdagrad; +import org.tensorflow.op.xla.XlaSparseCoreAdagradMomentum; +import org.tensorflow.op.xla.XlaSparseCoreAdam; +import org.tensorflow.op.xla.XlaSparseCoreFtrl; +import org.tensorflow.op.xla.XlaSparseCoreSgd; +import org.tensorflow.op.xla.XlaSparseDenseMatmul; +import org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithAdagradAndCsrInput; +import org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput; +import org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithAdamAndCsrInput; +import org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithFtrlAndCsrInput; +import org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithSgdAndCsrInput; +import org.tensorflow.op.xla.XlaSparseDenseMatmulWithCsrInput; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; import org.tensorflow.types.family.TType; /** @@ -94,6 +108,304 @@ public XlaSendToHost xlaSendToHost(Operand input, String key) { return 
XlaSendToHost.create(scope, input, key); } + /** + * The XlaSparseCoreAdagrad operation + * + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param accumulator The accumulator value + * @param embeddingTable The embeddingTable value + * @param featureWidth The value of the featureWidth attribute + * @return a new instance of XlaSparseCoreAdagrad + */ + public XlaSparseCoreAdagrad xlaSparseCoreAdagrad(Operand indices, + Operand gradient, Operand learningRate, Operand accumulator, + Operand embeddingTable, Long featureWidth) { + return XlaSparseCoreAdagrad.create(scope, indices, gradient, learningRate, accumulator, embeddingTable, featureWidth); + } + + /** + * The XlaSparseCoreAdagradMomentum operation + * + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param beta1 The beta1 value + * @param epsilon The epsilon value + * @param accumulator The accumulator value + * @param momentum The momentum value + * @param embeddingTable The embeddingTable value + * @param featureWidth The value of the featureWidth attribute + * @param useNesterov The value of the useNesterov attribute + * @param beta2 The value of the beta2 attribute + * @param exponent The value of the exponent attribute + * @return a new instance of XlaSparseCoreAdagradMomentum + */ + public XlaSparseCoreAdagradMomentum xlaSparseCoreAdagradMomentum(Operand indices, + Operand gradient, Operand learningRate, Operand beta1, + Operand epsilon, Operand accumulator, Operand momentum, + Operand embeddingTable, Long featureWidth, Boolean useNesterov, Float beta2, + Float exponent) { + return XlaSparseCoreAdagradMomentum.create(scope, indices, gradient, learningRate, beta1, epsilon, accumulator, momentum, embeddingTable, featureWidth, useNesterov, beta2, exponent); + } + + /** + * The XlaSparseCoreAdam operation + * + * @param embeddingTable The embeddingTable 
value + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param momentum The momentum value + * @param velocity The velocity value + * @param beta1 The beta1 value + * @param beta2 The beta2 value + * @param epsilon The epsilon value + * @param featureWidth The value of the featureWidth attribute + * @param useSumInsideSqrt The value of the useSumInsideSqrt attribute + * @return a new instance of XlaSparseCoreAdam + */ + public XlaSparseCoreAdam xlaSparseCoreAdam(Operand embeddingTable, + Operand indices, Operand gradient, Operand learningRate, + Operand momentum, Operand velocity, Operand beta1, + Operand beta2, Operand epsilon, Long featureWidth, + Boolean useSumInsideSqrt) { + return XlaSparseCoreAdam.create(scope, embeddingTable, indices, gradient, learningRate, momentum, velocity, beta1, beta2, epsilon, featureWidth, useSumInsideSqrt); + } + + /** + * The XlaSparseCoreFtrl operation + * + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param linear The linear value + * @param learningRate The learningRate value + * @param indices The indices value + * @param gradient The gradient value + * @param beta The beta value + * @param learningRatePower The learningRatePower value + * @param l2RegularizationStrength The l2RegularizationStrength value + * @param featureWidth The value of the featureWidth attribute + * @param multiplyLinearByLearningRate The value of the multiplyLinearByLearningRate attribute + * @param l1RegularizationStrength The value of the l1RegularizationStrength attribute + * @return a new instance of XlaSparseCoreFtrl + */ + public XlaSparseCoreFtrl xlaSparseCoreFtrl(Operand embeddingTable, + Operand accumulator, Operand linear, Operand learningRate, + Operand indices, Operand gradient, Operand beta, + Operand learningRatePower, Operand l2RegularizationStrength, + Long featureWidth, Boolean multiplyLinearByLearningRate, 
Float l1RegularizationStrength) { + return XlaSparseCoreFtrl.create(scope, embeddingTable, accumulator, linear, learningRate, indices, gradient, beta, learningRatePower, l2RegularizationStrength, featureWidth, multiplyLinearByLearningRate, l1RegularizationStrength); + } + + /** + * The XlaSparseCoreSgd operation + * + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param featureWidth The value of the featureWidth attribute + * @return a new instance of XlaSparseCoreSgd + */ + public XlaSparseCoreSgd xlaSparseCoreSgd(Operand indices, Operand gradient, + Operand learningRate, Operand embeddingTable, Long featureWidth) { + return XlaSparseCoreSgd.create(scope, indices, gradient, learningRate, embeddingTable, featureWidth); + } + + /** + * The XlaSparseDenseMatmul operation + * + * @param rowIds The rowIds value + * @param colIds The colIds value + * @param values The values value + * @param offsets The offsets value + * @param embeddingTable The embeddingTable value + * @param maxIdsPerPartition The value of the maxIdsPerPartition attribute + * @param maxUniqueIdsPerPartition The value of the maxUniqueIdsPerPartition attribute + * @param inputSize The value of the inputSize attribute + * @return a new instance of XlaSparseDenseMatmul + */ + public XlaSparseDenseMatmul xlaSparseDenseMatmul(Operand rowIds, + Operand colIds, Operand values, Operand offsets, + Operand embeddingTable, Long maxIdsPerPartition, Long maxUniqueIdsPerPartition, + Long inputSize) { + return XlaSparseDenseMatmul.create(scope, rowIds, colIds, values, offsets, embeddingTable, maxIdsPerPartition, maxUniqueIdsPerPartition, inputSize); + } + + /** + * The XlaSparseDenseMatmulGradWithAdagradAndCsrInput operation + * + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param 
sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithAdagradAndCsrInput + */ + public XlaSparseDenseMatmulGradWithAdagradAndCsrInput xlaSparseDenseMatmulGradWithAdagradAndCsrInput( + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand accumulator, Operand numMinibatchesPerPhysicalSparseCore, + String tableName, XlaSparseDenseMatmulGradWithAdagradAndCsrInput.Options... options) { + return XlaSparseDenseMatmulGradWithAdagradAndCsrInput.create(scope, rowPointers, sortedSampleIds, sortedTokenIds, sortedGains, activationGradients, learningRate, embeddingTable, accumulator, numMinibatchesPerPhysicalSparseCore, tableName, options); + } + + /** + * The XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput operation + * + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param momenta The momenta value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param useNesterov The value of the useNesterov attribute + * @param exponent The value of the exponent attribute + * @param beta1 The value of the beta1 attribute + * 
@param beta2 The value of the beta2 attribute + * @param epsilon The value of the epsilon attribute + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput + */ + public XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput xlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput( + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand accumulator, Operand momenta, + Operand numMinibatchesPerPhysicalSparseCore, Boolean useNesterov, Float exponent, + Float beta1, Float beta2, Float epsilon, String tableName, + XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.Options... options) { + return XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.create(scope, rowPointers, sortedSampleIds, sortedTokenIds, sortedGains, activationGradients, learningRate, embeddingTable, accumulator, momenta, numMinibatchesPerPhysicalSparseCore, useNesterov, exponent, beta1, beta2, epsilon, tableName, options); + } + + /** + * The XlaSparseDenseMatmulGradWithAdamAndCsrInput operation + * + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param momenta The momenta value + * @param velocity The velocity value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param useSumInsideSqrt The value of the useSumInsideSqrt attribute + * @param beta1 The value of the beta1 attribute + * @param beta2 The value of the beta2 attribute + * @param epsilon The value of the epsilon attribute + 
* @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithAdamAndCsrInput + */ + public XlaSparseDenseMatmulGradWithAdamAndCsrInput xlaSparseDenseMatmulGradWithAdamAndCsrInput( + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, Operand momenta, + Operand velocity, Operand numMinibatchesPerPhysicalSparseCore, + Boolean useSumInsideSqrt, Float beta1, Float beta2, Float epsilon, String tableName, + XlaSparseDenseMatmulGradWithAdamAndCsrInput.Options... options) { + return XlaSparseDenseMatmulGradWithAdamAndCsrInput.create(scope, rowPointers, sortedSampleIds, sortedTokenIds, sortedGains, activationGradients, learningRate, embeddingTable, momenta, velocity, numMinibatchesPerPhysicalSparseCore, useSumInsideSqrt, beta1, beta2, epsilon, tableName, options); + } + + /** + * The XlaSparseDenseMatmulGradWithFtrlAndCsrInput operation + * + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param linear The linear value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param multiplyLinearByLearningRate The value of the multiplyLinearByLearningRate attribute + * @param beta The value of the beta attribute + * @param learningRatePower The value of the learningRatePower attribute + * @param l1RegularizationStrength The value of the l1RegularizationStrength attribute + * @param l2RegularizationStrength The value of the l2RegularizationStrength attribute + * @param 
tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithFtrlAndCsrInput + */ + public XlaSparseDenseMatmulGradWithFtrlAndCsrInput xlaSparseDenseMatmulGradWithFtrlAndCsrInput( + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand accumulator, Operand linear, + Operand numMinibatchesPerPhysicalSparseCore, Boolean multiplyLinearByLearningRate, + Float beta, Float learningRatePower, Float l1RegularizationStrength, + Float l2RegularizationStrength, String tableName, + XlaSparseDenseMatmulGradWithFtrlAndCsrInput.Options... options) { + return XlaSparseDenseMatmulGradWithFtrlAndCsrInput.create(scope, rowPointers, sortedSampleIds, sortedTokenIds, sortedGains, activationGradients, learningRate, embeddingTable, accumulator, linear, numMinibatchesPerPhysicalSparseCore, multiplyLinearByLearningRate, beta, learningRatePower, l1RegularizationStrength, l2RegularizationStrength, tableName, options); + } + + /** + * The XlaSparseDenseMatmulGradWithSgdAndCsrInput operation + * + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithSgdAndCsrInput + */ + public XlaSparseDenseMatmulGradWithSgdAndCsrInput xlaSparseDenseMatmulGradWithSgdAndCsrInput( + Operand rowPointers, Operand sortedSampleIds, Operand 
sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand numMinibatchesPerPhysicalSparseCore, String tableName, + XlaSparseDenseMatmulGradWithSgdAndCsrInput.Options... options) { + return XlaSparseDenseMatmulGradWithSgdAndCsrInput.create(scope, rowPointers, sortedSampleIds, sortedTokenIds, sortedGains, activationGradients, learningRate, embeddingTable, numMinibatchesPerPhysicalSparseCore, tableName, options); + } + + /** + * The XlaSparseDenseMatmulWithCsrInput operation + * + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param embeddingTable The embeddingTable value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param inputSize The value of the inputSize attribute + * @param quantizationConfigLow The value of the quantizationConfigLow attribute + * @param quantizationConfigHigh The value of the quantizationConfigHigh attribute + * @param quantizationConfigNumBuckets The value of the quantizationConfigNumBuckets attribute + * @param tableName The value of the tableName attribute + * @return a new instance of XlaSparseDenseMatmulWithCsrInput + */ + public XlaSparseDenseMatmulWithCsrInput xlaSparseDenseMatmulWithCsrInput( + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand embeddingTable, + Operand numMinibatchesPerPhysicalSparseCore, Long inputSize, + Float quantizationConfigLow, Float quantizationConfigHigh, Long quantizationConfigNumBuckets, + String tableName) { + return XlaSparseDenseMatmulWithCsrInput.create(scope, rowPointers, sortedSampleIds, sortedTokenIds, sortedGains, embeddingTable, numMinibatchesPerPhysicalSparseCore, inputSize, quantizationConfigLow, quantizationConfigHigh, quantizationConfigNumBuckets, tableName); + } + /** * Get the 
parent {@link Ops} object. */ diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DynamicStitch.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DynamicStitch.java index 258aabce0e0..9aba2968627 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DynamicStitch.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/DynamicStitch.java @@ -54,7 +54,7 @@ * must have {@code data[i].shape = indices[i].shape + constant}. In terms of this * {@code constant}, the output shape is *

- * merged.shape = [max(indices)] + constant
+ * merged.shape = [max(indices) + 1] + constant
  * 
*

Values are merged in order, so if an index appears in both {@code indices[m][i]} and * {@code indices[n][j]} for {@code (m,i) < (n,j)} the slice {@code data[n][j]} will appear in the diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Tile.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Tile.java index fa25cd34464..c9a58b9158c 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Tile.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/core/Tile.java @@ -93,7 +93,7 @@ public Tile(Operation operation) { * Factory method to create a class wrapping a new Tile operation. * * @param scope current scope - * @param input 1-D or higher. + * @param input Can be of any rank. * @param multiples 1-D. Length must be the same as the number of dimensions in {@code input} * @param data type for {@code Tile} output and operands * @return a new instance of Tile @@ -128,7 +128,7 @@ public Output asOutput() { ) public static class Inputs extends RawOpInputs> { /** - * 1-D or higher. + * Can be of any rank. */ public final Operand input; diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetFingerprint.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetFingerprint.java new file mode 100644 index 00000000000..573670b6b26 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/DatasetFingerprint.java @@ -0,0 +1,107 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.family.TType; + +/** + * Returns the fingerprint of {@code input_dataset}. + * Returns the fingerprint of {@code input_dataset}. + */ +@OpMetadata( + opType = DatasetFingerprint.OP_NAME, + inputsClass = DatasetFingerprint.Inputs.class +) +@Operator( + group = "data" +) +public final class DatasetFingerprint extends RawOp implements Operand { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "DatasetFingerprint"; + + private Output fingerprint; + + @SuppressWarnings("unchecked") + public DatasetFingerprint(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + fingerprint = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new DatasetFingerprint operation. 
+ * + * @param scope current scope + * @param inputDataset A variant tensor representing the dataset to return fingerprint for. + * @return a new instance of DatasetFingerprint + */ + @Endpoint( + describeByClass = true + ) + public static DatasetFingerprint create(Scope scope, Operand inputDataset) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "DatasetFingerprint"); + opBuilder.addInput(inputDataset.asOutput()); + return new DatasetFingerprint(opBuilder.build()); + } + + /** + * Gets fingerprint. + * The fingerprint of {@code input_dataset} in {@code uint64} + * @return fingerprint. + */ + public Output fingerprint() { + return fingerprint; + } + + @Override + @SuppressWarnings("unchecked") + public Output asOutput() { + return (Output) fingerprint; + } + + @OpInputsMetadata( + outputsClass = DatasetFingerprint.class + ) + public static class Inputs extends RawOpInputs { + /** + * A variant tensor representing the dataset to return fingerprint for. + */ + public final Operand inputDataset; + + public Inputs(GraphOperation op) { + super(new DatasetFingerprint(op), op, Arrays.asList()); + int inputIndex = 0; + inputDataset = (Operand) op.input(inputIndex++); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ListSnapshotChunksDataset.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ListSnapshotChunksDataset.java new file mode 100644 index 00000000000..0fe1bbb447b --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/data/ListSnapshotChunksDataset.java @@ -0,0 +1,132 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.data; + +import java.util.Arrays; +import java.util.List; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.ndarray.Shape; +import org.tensorflow.op.Operands; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.proto.DataType; +import org.tensorflow.types.TString; +import org.tensorflow.types.family.TType; + +/** + * The ListSnapshotChunksDataset operation + */ +@OpMetadata( + opType = ListSnapshotChunksDataset.OP_NAME, + inputsClass = ListSnapshotChunksDataset.Inputs.class +) +@Operator( + group = "data" +) +public final class ListSnapshotChunksDataset extends RawOp implements Operand { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ListSnapshotChunksDataset"; + + private Output handle; + + @SuppressWarnings("unchecked") + public ListSnapshotChunksDataset(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + handle = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new 
ListSnapshotChunksDataset operation. + * + * @param scope current scope + * @param snapshotPath The snapshotPath value + * @param outputTypes The value of the outputTypes attribute + * @param outputShapes The value of the outputShapes attribute + * @return a new instance of ListSnapshotChunksDataset + */ + @Endpoint( + describeByClass = true + ) + public static ListSnapshotChunksDataset create(Scope scope, Operand snapshotPath, + List> outputTypes, List outputShapes) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "ListSnapshotChunksDataset"); + opBuilder.addInput(snapshotPath.asOutput()); + opBuilder.setAttr("output_types", Operands.toDataTypes(outputTypes)); + Shape[] outputShapesArray = new Shape[outputShapes.size()]; + for (int i = 0 ; i < outputShapesArray.length ; i++) { + outputShapesArray[i] = outputShapes.get(i); + } + opBuilder.setAttr("output_shapes", outputShapesArray); + return new ListSnapshotChunksDataset(opBuilder.build()); + } + + /** + * Gets handle. + * + * @return handle. 
+ */ + public Output handle() { + return handle; + } + + @Override + @SuppressWarnings("unchecked") + public Output asOutput() { + return (Output) handle; + } + + @OpInputsMetadata( + outputsClass = ListSnapshotChunksDataset.class + ) + public static class Inputs extends RawOpInputs { + /** + * The snapshotPath input + */ + public final Operand snapshotPath; + + /** + * The outputTypes attribute + */ + public final DataType[] outputTypes; + + /** + * The outputShapes attribute + */ + public final Shape[] outputShapes; + + public Inputs(GraphOperation op) { + super(new ListSnapshotChunksDataset(op), op, Arrays.asList("output_types", "output_shapes")); + int inputIndex = 0; + snapshotPath = (Operand) op.input(inputIndex++); + outputTypes = op.attributes().getAttrTypeList("output_types"); + outputShapes = op.attributes().getAttrShapeList("output_shapes"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/MatMul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/MatMul.java index 90a74daac9f..a984a502a7a 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/MatMul.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/linalg/MatMul.java @@ -91,6 +91,12 @@ public static MatMul create(Scope scope, Operand a, Oper if (opts.transposeB != null) { opBuilder.setAttr("transpose_b", opts.transposeB); } + if (opts.gradA != null) { + opBuilder.setAttr("grad_a", opts.gradA); + } + if (opts.gradB != null) { + opBuilder.setAttr("grad_b", opts.gradB); + } } } return new MatMul<>(opBuilder.build()); @@ -116,6 +122,26 @@ public static Options transposeB(Boolean transposeB) { return new Options().transposeB(transposeB); } + /** + * Sets the gradA option. + * + * @param gradA the gradA option + * @return this Options instance. + */ + public static Options gradA(Boolean gradA) { + return new Options().gradA(gradA); + } + + /** + * Sets the gradB option. 
+ * + * @param gradB the gradB option + * @return this Options instance. + */ + public static Options gradB(Boolean gradB) { + return new Options().gradB(gradB); + } + /** * Gets product. * @@ -138,6 +164,10 @@ public static class Options { private Boolean transposeB; + private Boolean gradA; + + private Boolean gradB; + private Options() { } @@ -162,6 +192,28 @@ public Options transposeB(Boolean transposeB) { this.transposeB = transposeB; return this; } + + /** + * Sets the gradA option. + * + * @param gradA the gradA option + * @return this Options instance. + */ + public Options gradA(Boolean gradA) { + this.gradA = gradA; + return this; + } + + /** + * Sets the gradB option. + * + * @param gradB the gradB option + * @return this Options instance. + */ + public Options gradB(Boolean gradB) { + this.gradB = gradB; + return this; + } } @OpInputsMetadata( @@ -193,14 +245,26 @@ public static class Inputs extends RawOpInputs> { */ public final DataType T; + /** + * The gradA attribute + */ + public final boolean gradA; + + /** + * The gradB attribute + */ + public final boolean gradB; + public Inputs(GraphOperation op) { - super(new MatMul<>(op), op, Arrays.asList("transpose_a", "transpose_b", "T")); + super(new MatMul<>(op), op, Arrays.asList("transpose_a", "transpose_b", "T", "grad_a", "grad_b")); int inputIndex = 0; a = (Operand) op.input(inputIndex++); b = (Operand) op.input(inputIndex++); transposeA = op.attributes().getAttrBool("transpose_a"); transposeB = op.attributes().getAttrBool("transpose_b"); T = op.attributes().getAttrType("T"); + gradA = op.attributes().getAttrBool("grad_a"); + gradB = op.attributes().getAttrBool("grad_b"); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ComputeDedupDataSize.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ComputeDedupDataSize.java new file mode 100644 index 00000000000..4305affd43e --- /dev/null +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ComputeDedupDataSize.java @@ -0,0 +1,107 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.tpu; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TInt32; + +/** + * An op computes the size of the deduplication data from embedding core and returns the updated config. + * This op is to compute size of the deduplication data so to provide this + * information to the op that computes the tuple mask of deduplication data can + * have static output shape. 
+ */ +@OpMetadata( + opType = ComputeDedupDataSize.OP_NAME, + inputsClass = ComputeDedupDataSize.Inputs.class +) +@Operator( + group = "tpu" +) +public final class ComputeDedupDataSize extends RawOp implements Operand { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "ComputeDedupDataSize"; + + private Output numElements; + + public ComputeDedupDataSize(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + numElements = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new ComputeDedupDataSize operation. + * + * @param scope current scope + * @param config Serialized TPUEmbeddingConfiguration proto. + * @return a new instance of ComputeDedupDataSize + */ + @Endpoint( + describeByClass = true + ) + public static ComputeDedupDataSize create(Scope scope, String config) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "ComputeDedupDataSize"); + opBuilder.setAttr("config", config); + return new ComputeDedupDataSize(opBuilder.build()); + } + + /** + * Gets numElements. + * The size of the deduplicated data from infeed. + * @return numElements. + */ + public Output numElements() { + return numElements; + } + + @Override + public Output asOutput() { + return numElements; + } + + @OpInputsMetadata( + outputsClass = ComputeDedupDataSize.class + ) + public static class Inputs extends RawOpInputs { + /** + * Serialized TPUEmbeddingConfiguration proto. 
+ */ + public final String config; + + public Inputs(GraphOperation op) { + super(new ComputeDedupDataSize(op), op, Arrays.asList("config")); + int inputIndex = 0; + config = op.attributes().getAttrString("config"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConvertToCooTensor.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConvertToCooTensor.java new file mode 100644 index 00000000000..efec4caa44a --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/tpu/ConvertToCooTensor.java @@ -0,0 +1,157 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! 
package org.tensorflow.op.tpu;

import java.util.Arrays;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Output;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;

/**
 * The ConvertToCooTensor operation
 */
@OpMetadata(
    opType = ConvertToCooTensor.OP_NAME,
    inputsClass = ConvertToCooTensor.Inputs.class
)
@Operator(
    group = "tpu"
)
public final class ConvertToCooTensor extends RawOp {
  /**
   * The name of this op, as known by TensorFlow core engine
   */
  public static final String OP_NAME = "ConvertToCooTensor";

  private Output<TInt32> rowIds;

  private Output<TInt32> colIds;

  private Output<TFloat32> gains;

  public ConvertToCooTensor(Operation operation) {
    super(operation, OP_NAME);
    int outputIdx = 0;
    rowIds = operation.output(outputIdx++);
    colIds = operation.output(outputIdx++);
    gains = operation.output(outputIdx++);
  }

  /**
   * Factory method to create a class wrapping a new ConvertToCooTensor operation.
   *
   * @param scope current scope
   * @param indicesOrRowSplits The indicesOrRowSplits value
   * @param values The values value
   * @param weights The weights value
   * @param sampleCount The value of the sampleCount attribute
   * @param combiner The value of the combiner attribute
   * @return a new instance of ConvertToCooTensor
   */
  @Endpoint(
      describeByClass = true
  )
  public static ConvertToCooTensor create(Scope scope, Operand<TInt32> indicesOrRowSplits,
      Operand<TInt32> values, Operand<TFloat32> weights, Long sampleCount, String combiner) {
    OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "ConvertToCooTensor");
    opBuilder.addInput(indicesOrRowSplits.asOutput());
    opBuilder.addInput(values.asOutput());
    opBuilder.addInput(weights.asOutput());
    opBuilder.setAttr("sample_count", sampleCount);
    opBuilder.setAttr("combiner", combiner);
    return new ConvertToCooTensor(opBuilder.build());
  }

  /**
   * Gets rowIds.
   *
   * @return rowIds.
   */
  public Output<TInt32> rowIds() {
    return rowIds;
  }

  /**
   * Gets colIds.
   *
   * @return colIds.
   */
  public Output<TInt32> colIds() {
    return colIds;
  }

  /**
   * Gets gains.
   *
   * @return gains.
   */
  public Output<TFloat32> gains() {
    return gains;
  }

  @OpInputsMetadata(
      outputsClass = ConvertToCooTensor.class
  )
  public static class Inputs extends RawOpInputs<ConvertToCooTensor> {
    /**
     * The indicesOrRowSplits input
     */
    public final Operand<TInt32> indicesOrRowSplits;

    /**
     * The values input
     */
    public final Operand<TInt32> values;

    /**
     * The weights input
     */
    public final Operand<TFloat32> weights;

    /**
     * The sampleCount attribute
     */
    public final long sampleCount;

    /**
     * The combiner attribute
     */
    public final String combiner;

    public Inputs(GraphOperation op) {
      super(new ConvertToCooTensor(op), op, Arrays.asList("sample_count", "combiner"));
      int inputIndex = 0;
      indicesOrRowSplits = (Operand<TInt32>) op.input(inputIndex++);
      values = (Operand<TInt32>) op.input(inputIndex++);
      weights = (Operand<TFloat32>) op.input(inputIndex++);
      sampleCount = op.attributes().getAttrInt("sample_count");
      combiner = op.attributes().getAttrString("combiner");
    }
  }
}
package org.tensorflow.op.tpu;

import java.util.Arrays;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Output;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;

/**
 * The GetMinibatchSplitsWithPhysicalReplica operation
 */
@OpMetadata(
    opType = GetMinibatchSplitsWithPhysicalReplica.OP_NAME,
    inputsClass = GetMinibatchSplitsWithPhysicalReplica.Inputs.class
)
@Operator(
    group = "tpu"
)
public final class GetMinibatchSplitsWithPhysicalReplica extends RawOp {
  /**
   * The name of this op, as known by TensorFlow core engine
   */
  public static final String OP_NAME = "GetMinibatchSplitsWithPhysicalReplica";

  private Output<TInt32> sortedRowIds;

  private Output<TInt32> sortedColIds;

  private Output<TFloat32> sortedGains;

  private Output<TInt64> splits;

  private Output<TInt32> idCounts;

  private Output<TInt32> maxIds;

  private Output<TInt32> maxUniques;

  public GetMinibatchSplitsWithPhysicalReplica(Operation operation) {
    super(operation, OP_NAME);
    int outputIdx = 0;
    sortedRowIds = operation.output(outputIdx++);
    sortedColIds = operation.output(outputIdx++);
    sortedGains = operation.output(outputIdx++);
    splits = operation.output(outputIdx++);
    idCounts = operation.output(outputIdx++);
    maxIds = operation.output(outputIdx++);
    maxUniques = operation.output(outputIdx++);
  }

  /**
   * Factory method to create a class wrapping a new GetMinibatchSplitsWithPhysicalReplica operation.
   *
   * @param scope current scope
   * @param programKey The programKey value
   * @param rowIds The rowIds value
   * @param colIds The colIds value
   * @param gains The gains value
   * @param sampleCount The value of the sampleCount attribute
   * @param numReplica The value of the numReplica attribute
   * @param tableVocabSize The value of the tableVocabSize attribute
   * @param featureWidth The value of the featureWidth attribute
   * @param numScPerChip The value of the numScPerChip attribute
   * @param tableName The value of the tableName attribute
   * @param miniBatchSplits The value of the miniBatchSplits attribute
   * @return a new instance of GetMinibatchSplitsWithPhysicalReplica
   */
  @Endpoint(
      describeByClass = true
  )
  public static GetMinibatchSplitsWithPhysicalReplica create(Scope scope,
      Operand<TString> programKey, Operand<TInt32> rowIds, Operand<TInt32> colIds,
      Operand<TFloat32> gains, Long sampleCount, Long numReplica, Long tableVocabSize,
      Long featureWidth, Long numScPerChip, String tableName, String miniBatchSplits) {
    OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "GetMinibatchSplitsWithPhysicalReplica");
    opBuilder.addInput(programKey.asOutput());
    opBuilder.addInput(rowIds.asOutput());
    opBuilder.addInput(colIds.asOutput());
    opBuilder.addInput(gains.asOutput());
    opBuilder.setAttr("sample_count", sampleCount);
    opBuilder.setAttr("num_replica", numReplica);
    opBuilder.setAttr("table_vocab_size", tableVocabSize);
    opBuilder.setAttr("feature_width", featureWidth);
    opBuilder.setAttr("num_sc_per_chip", numScPerChip);
    opBuilder.setAttr("table_name", tableName);
    opBuilder.setAttr("mini_batch_splits", miniBatchSplits);
    return new GetMinibatchSplitsWithPhysicalReplica(opBuilder.build());
  }

  /**
   * Gets sortedRowIds.
   *
   * @return sortedRowIds.
   */
  public Output<TInt32> sortedRowIds() {
    return sortedRowIds;
  }

  /**
   * Gets sortedColIds.
   *
   * @return sortedColIds.
   */
  public Output<TInt32> sortedColIds() {
    return sortedColIds;
  }

  /**
   * Gets sortedGains.
   *
   * @return sortedGains.
   */
  public Output<TFloat32> sortedGains() {
    return sortedGains;
  }

  /**
   * Gets splits.
   *
   * @return splits.
   */
  public Output<TInt64> splits() {
    return splits;
  }

  /**
   * Gets idCounts.
   *
   * @return idCounts.
   */
  public Output<TInt32> idCounts() {
    return idCounts;
  }

  /**
   * Gets maxIds.
   *
   * @return maxIds.
   */
  public Output<TInt32> maxIds() {
    return maxIds;
  }

  /**
   * Gets maxUniques.
   *
   * @return maxUniques.
   */
  public Output<TInt32> maxUniques() {
    return maxUniques;
  }

  @OpInputsMetadata(
      outputsClass = GetMinibatchSplitsWithPhysicalReplica.class
  )
  public static class Inputs extends RawOpInputs<GetMinibatchSplitsWithPhysicalReplica> {
    /**
     * The programKey input
     */
    public final Operand<TString> programKey;

    /**
     * The rowIds input
     */
    public final Operand<TInt32> rowIds;

    /**
     * The colIds input
     */
    public final Operand<TInt32> colIds;

    /**
     * The gains input
     */
    public final Operand<TFloat32> gains;

    /**
     * The sampleCount attribute
     */
    public final long sampleCount;

    /**
     * The numReplica attribute
     */
    public final long numReplica;

    /**
     * The tableVocabSize attribute
     */
    public final long tableVocabSize;

    /**
     * The featureWidth attribute
     */
    public final long featureWidth;

    /**
     * The numScPerChip attribute
     */
    public final long numScPerChip;

    /**
     * The tableName attribute
     */
    public final String tableName;

    /**
     * The miniBatchSplits attribute
     */
    public final String miniBatchSplits;

    public Inputs(GraphOperation op) {
      super(new GetMinibatchSplitsWithPhysicalReplica(op), op, Arrays.asList("sample_count", "num_replica", "table_vocab_size", "feature_width", "num_sc_per_chip", "table_name", "mini_batch_splits"));
      int inputIndex = 0;
      programKey = (Operand<TString>) op.input(inputIndex++);
      rowIds = (Operand<TInt32>) op.input(inputIndex++);
      colIds = (Operand<TInt32>) op.input(inputIndex++);
      gains = (Operand<TFloat32>) op.input(inputIndex++);
      sampleCount = op.attributes().getAttrInt("sample_count");
      numReplica = op.attributes().getAttrInt("num_replica");
      tableVocabSize = op.attributes().getAttrInt("table_vocab_size");
      featureWidth = op.attributes().getAttrInt("feature_width");
      numScPerChip = op.attributes().getAttrInt("num_sc_per_chip");
      tableName = op.attributes().getAttrString("table_name");
      miniBatchSplits = op.attributes().getAttrString("mini_batch_splits");
    }
  }
}
package org.tensorflow.op.tpu;

import java.util.Arrays;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Output;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TInt64;
import org.tensorflow.types.TString;

/**
 * The GetMinibatchesInCsrWithPhysicalReplica operation
 */
@OpMetadata(
    opType = GetMinibatchesInCsrWithPhysicalReplica.OP_NAME,
    inputsClass = GetMinibatchesInCsrWithPhysicalReplica.Inputs.class
)
@Operator(
    group = "tpu"
)
public final class GetMinibatchesInCsrWithPhysicalReplica extends RawOp {
  /**
   * The name of this op, as known by TensorFlow core engine
   */
  public static final String OP_NAME = "GetMinibatchesInCsrWithPhysicalReplica";

  private Output<TInt32> rowPointers;

  private Output<TInt32> sortedSampleIds;

  private Output<TInt32> sortedTokenIds;

  private Output<TFloat32> sortedGains;

  private Output<TInt32> rowPointersUnpaddedSize;

  private Output<TInt32> idsUnpaddedSize;

  private Output<TInt32> numMinibatchesPerPhysicalSparseCore;

  public GetMinibatchesInCsrWithPhysicalReplica(Operation operation) {
    super(operation, OP_NAME);
    int outputIdx = 0;
    rowPointers = operation.output(outputIdx++);
    sortedSampleIds = operation.output(outputIdx++);
    sortedTokenIds = operation.output(outputIdx++);
    sortedGains = operation.output(outputIdx++);
    rowPointersUnpaddedSize = operation.output(outputIdx++);
    idsUnpaddedSize = operation.output(outputIdx++);
    numMinibatchesPerPhysicalSparseCore = operation.output(outputIdx++);
  }

  /**
   * Factory method to create a class wrapping a new GetMinibatchesInCsrWithPhysicalReplica operation.
   *
   * @param scope current scope
   * @param programKey The programKey value
   * @param rowIds The rowIds value
   * @param colIds The colIds value
   * @param gains The gains value
   * @param splits The splits value
   * @param idCounts The idCounts value
   * @param sampleCount The value of the sampleCount attribute
   * @param numReplica The value of the numReplica attribute
   * @param maxMinibatchesPerSc The value of the maxMinibatchesPerSc attribute
   * @param maxIdsPerChipPerSample The value of the maxIdsPerChipPerSample attribute
   * @param tableVocabSize The value of the tableVocabSize attribute
   * @param featureWidth The value of the featureWidth attribute
   * @param numScPerChip The value of the numScPerChip attribute
   * @param tableName The value of the tableName attribute
   * @param miniBatchInCsr The value of the miniBatchInCsr attribute
   * @return a new instance of GetMinibatchesInCsrWithPhysicalReplica
   */
  @Endpoint(
      describeByClass = true
  )
  public static GetMinibatchesInCsrWithPhysicalReplica create(Scope scope,
      Operand<TString> programKey, Operand<TInt32> rowIds, Operand<TInt32> colIds,
      Operand<TFloat32> gains, Operand<TInt64> splits, Operand<TInt32> idCounts, Long sampleCount,
      Long numReplica, Long maxMinibatchesPerSc, Long maxIdsPerChipPerSample, Long tableVocabSize,
      Long featureWidth, Long numScPerChip, String tableName, String miniBatchInCsr) {
    OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "GetMinibatchesInCsrWithPhysicalReplica");
    opBuilder.addInput(programKey.asOutput());
    opBuilder.addInput(rowIds.asOutput());
    opBuilder.addInput(colIds.asOutput());
    opBuilder.addInput(gains.asOutput());
    opBuilder.addInput(splits.asOutput());
    opBuilder.addInput(idCounts.asOutput());
    opBuilder.setAttr("sample_count", sampleCount);
    opBuilder.setAttr("num_replica", numReplica);
    opBuilder.setAttr("max_minibatches_per_sc", maxMinibatchesPerSc);
    opBuilder.setAttr("max_ids_per_chip_per_sample", maxIdsPerChipPerSample);
    opBuilder.setAttr("table_vocab_size", tableVocabSize);
    opBuilder.setAttr("feature_width", featureWidth);
    opBuilder.setAttr("num_sc_per_chip", numScPerChip);
    opBuilder.setAttr("table_name", tableName);
    opBuilder.setAttr("mini_batch_in_csr", miniBatchInCsr);
    return new GetMinibatchesInCsrWithPhysicalReplica(opBuilder.build());
  }

  /**
   * Gets rowPointers.
   *
   * @return rowPointers.
   */
  public Output<TInt32> rowPointers() {
    return rowPointers;
  }

  /**
   * Gets sortedSampleIds.
   *
   * @return sortedSampleIds.
   */
  public Output<TInt32> sortedSampleIds() {
    return sortedSampleIds;
  }

  /**
   * Gets sortedTokenIds.
   *
   * @return sortedTokenIds.
   */
  public Output<TInt32> sortedTokenIds() {
    return sortedTokenIds;
  }

  /**
   * Gets sortedGains.
   *
   * @return sortedGains.
   */
  public Output<TFloat32> sortedGains() {
    return sortedGains;
  }

  /**
   * Gets rowPointersUnpaddedSize.
   *
   * @return rowPointersUnpaddedSize.
   */
  public Output<TInt32> rowPointersUnpaddedSize() {
    return rowPointersUnpaddedSize;
  }

  /**
   * Gets idsUnpaddedSize.
   *
   * @return idsUnpaddedSize.
   */
  public Output<TInt32> idsUnpaddedSize() {
    return idsUnpaddedSize;
  }

  /**
   * Gets numMinibatchesPerPhysicalSparseCore.
   *
   * @return numMinibatchesPerPhysicalSparseCore.
   */
  public Output<TInt32> numMinibatchesPerPhysicalSparseCore() {
    return numMinibatchesPerPhysicalSparseCore;
  }

  @OpInputsMetadata(
      outputsClass = GetMinibatchesInCsrWithPhysicalReplica.class
  )
  public static class Inputs extends RawOpInputs<GetMinibatchesInCsrWithPhysicalReplica> {
    /**
     * The programKey input
     */
    public final Operand<TString> programKey;

    /**
     * The rowIds input
     */
    public final Operand<TInt32> rowIds;

    /**
     * The colIds input
     */
    public final Operand<TInt32> colIds;

    /**
     * The gains input
     */
    public final Operand<TFloat32> gains;

    /**
     * The splits input
     */
    public final Operand<TInt64> splits;

    /**
     * The idCounts input
     */
    public final Operand<TInt32> idCounts;

    /**
     * The sampleCount attribute
     */
    public final long sampleCount;

    /**
     * The numReplica attribute
     */
    public final long numReplica;

    /**
     * The maxMinibatchesPerSc attribute
     */
    public final long maxMinibatchesPerSc;

    /**
     * The maxIdsPerChipPerSample attribute
     */
    public final long maxIdsPerChipPerSample;

    /**
     * The tableVocabSize attribute
     */
    public final long tableVocabSize;

    /**
     * The featureWidth attribute
     */
    public final long featureWidth;

    /**
     * The numScPerChip attribute
     */
    public final long numScPerChip;

    /**
     * The tableName attribute
     */
    public final String tableName;

    /**
     * The miniBatchInCsr attribute
     */
    public final String miniBatchInCsr;

    public Inputs(GraphOperation op) {
      super(new GetMinibatchesInCsrWithPhysicalReplica(op), op, Arrays.asList("sample_count", "num_replica", "max_minibatches_per_sc", "max_ids_per_chip_per_sample", "table_vocab_size", "feature_width", "num_sc_per_chip", "table_name", "mini_batch_in_csr"));
      int inputIndex = 0;
      programKey = (Operand<TString>) op.input(inputIndex++);
      rowIds = (Operand<TInt32>) op.input(inputIndex++);
      colIds = (Operand<TInt32>) op.input(inputIndex++);
      gains = (Operand<TFloat32>) op.input(inputIndex++);
      splits = (Operand<TInt64>) op.input(inputIndex++);
      idCounts = (Operand<TInt32>) op.input(inputIndex++);
      sampleCount = op.attributes().getAttrInt("sample_count");
      numReplica = op.attributes().getAttrInt("num_replica");
      maxMinibatchesPerSc = op.attributes().getAttrInt("max_minibatches_per_sc");
      maxIdsPerChipPerSample = op.attributes().getAttrInt("max_ids_per_chip_per_sample");
      tableVocabSize = op.attributes().getAttrInt("table_vocab_size");
      featureWidth = op.attributes().getAttrInt("feature_width");
      numScPerChip = op.attributes().getAttrInt("num_sc_per_chip");
      tableName = op.attributes().getAttrString("table_name");
      miniBatchInCsr = op.attributes().getAttrString("mini_batch_in_csr");
    }
  }
}
package org.tensorflow.op.tpu;

import java.util.Arrays;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Output;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.TInt64;

/**
 * The GlobalIterId operation
 */
@OpMetadata(
    opType = GlobalIterId.OP_NAME,
    inputsClass = GlobalIterId.Inputs.class
)
@Operator(
    group = "tpu"
)
public final class GlobalIterId extends RawOp implements Operand<TInt64> {
  /**
   * The name of this op, as known by TensorFlow core engine
   */
  public static final String OP_NAME = "GlobalIterId";

  private Output<TInt64> iterId;

  public GlobalIterId(Operation operation) {
    super(operation, OP_NAME);
    int outputIdx = 0;
    iterId = operation.output(outputIdx++);
  }

  /**
   * Factory method to create a class wrapping a new GlobalIterId operation.
   *
   * @param scope current scope
   * @return a new instance of GlobalIterId
   */
  @Endpoint(
      describeByClass = true
  )
  public static GlobalIterId create(Scope scope) {
    OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "GlobalIterId");
    return new GlobalIterId(opBuilder.build());
  }

  /**
   * Gets iterId.
   *
   * @return iterId.
   */
  public Output<TInt64> iterId() {
    return iterId;
  }

  @Override
  public Output<TInt64> asOutput() {
    return iterId;
  }

  @OpInputsMetadata(
      outputsClass = GlobalIterId.class
  )
  public static class Inputs extends RawOpInputs<GlobalIterId> {
    public Inputs(GraphOperation op) {
      super(new GlobalIterId(op), op, Arrays.asList());
      int inputIndex = 0;
    }
  }
}
package org.tensorflow.op.tpu;

import java.util.Arrays;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.TString;

/**
 * The StoreMinibatchStatisticsInFdo operation
 */
@OpMetadata(
    opType = StoreMinibatchStatisticsInFdo.OP_NAME,
    inputsClass = StoreMinibatchStatisticsInFdo.Inputs.class
)
@Operator(
    group = "tpu"
)
public final class StoreMinibatchStatisticsInFdo extends RawOp {
  /**
   * The name of this op, as known by TensorFlow core engine
   */
  public static final String OP_NAME = "StoreMinibatchStatisticsInFdo";

  public StoreMinibatchStatisticsInFdo(Operation operation) {
    super(operation, OP_NAME);
  }

  /**
   * Factory method to create a class wrapping a new StoreMinibatchStatisticsInFdo operation.
   *
   * @param scope current scope
   * @param programKey The programKey value
   * @param maxIds The maxIds value
   * @param maxUniques The maxUniques value
   * @param sampleCount The value of the sampleCount attribute
   * @param numReplica The value of the numReplica attribute
   * @param featureWidth The value of the featureWidth attribute
   * @param numScPerChip The value of the numScPerChip attribute
   * @param tableName The value of the tableName attribute
   * @param miniBatchSplits The value of the miniBatchSplits attribute
   * @return a new instance of StoreMinibatchStatisticsInFdo
   */
  @Endpoint(
      describeByClass = true
  )
  public static StoreMinibatchStatisticsInFdo create(Scope scope, Operand<TString> programKey,
      Operand<TInt32> maxIds, Operand<TInt32> maxUniques, Long sampleCount, Long numReplica,
      Long featureWidth, Long numScPerChip, String tableName, String miniBatchSplits) {
    OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "StoreMinibatchStatisticsInFdo");
    opBuilder.addInput(programKey.asOutput());
    opBuilder.addInput(maxIds.asOutput());
    opBuilder.addInput(maxUniques.asOutput());
    opBuilder.setAttr("sample_count", sampleCount);
    opBuilder.setAttr("num_replica", numReplica);
    opBuilder.setAttr("feature_width", featureWidth);
    opBuilder.setAttr("num_sc_per_chip", numScPerChip);
    opBuilder.setAttr("table_name", tableName);
    opBuilder.setAttr("mini_batch_splits", miniBatchSplits);
    return new StoreMinibatchStatisticsInFdo(opBuilder.build());
  }

  @OpInputsMetadata(
      outputsClass = StoreMinibatchStatisticsInFdo.class
  )
  public static class Inputs extends RawOpInputs<StoreMinibatchStatisticsInFdo> {
    /**
     * The programKey input
     */
    public final Operand<TString> programKey;

    /**
     * The maxIds input
     */
    public final Operand<TInt32> maxIds;

    /**
     * The maxUniques input
     */
    public final Operand<TInt32> maxUniques;

    /**
     * The sampleCount attribute
     */
    public final long sampleCount;

    /**
     * The numReplica attribute
     */
    public final long numReplica;

    /**
     * The featureWidth attribute
     */
    public final long featureWidth;

    /**
     * The numScPerChip attribute
     */
    public final long numScPerChip;

    /**
     * The tableName attribute
     */
    public final String tableName;

    /**
     * The miniBatchSplits attribute
     */
    public final String miniBatchSplits;

    public Inputs(GraphOperation op) {
      super(new StoreMinibatchStatisticsInFdo(op), op, Arrays.asList("sample_count", "num_replica", "feature_width", "num_sc_per_chip", "table_name", "mini_batch_splits"));
      int inputIndex = 0;
      programKey = (Operand<TString>) op.input(inputIndex++);
      maxIds = (Operand<TInt32>) op.input(inputIndex++);
      maxUniques = (Operand<TInt32>) op.input(inputIndex++);
      sampleCount = op.attributes().getAttrInt("sample_count");
      numReplica = op.attributes().getAttrInt("num_replica");
      featureWidth = op.attributes().getAttrInt("feature_width");
      numScPerChip = op.attributes().getAttrInt("num_sc_per_chip");
      tableName = op.attributes().getAttrString("table_name");
      miniBatchSplits = op.attributes().getAttrString("mini_batch_splits");
    }
  }
}
package org.tensorflow.op.tpu;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Output;
import org.tensorflow.op.Operands;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.proto.DataType;
import org.tensorflow.types.family.TType;

/**
 * The TPUAnnotateTensorsWithDynamicShape operation
 */
@OpMetadata(
    opType = TPUAnnotateTensorsWithDynamicShape.OP_NAME,
    inputsClass = TPUAnnotateTensorsWithDynamicShape.Inputs.class
)
@Operator(
    group = "tpu"
)
public final class TPUAnnotateTensorsWithDynamicShape extends RawOp implements Iterable<Operand<TType>> {
  /**
   * The name of this op, as known by TensorFlow core engine
   */
  public static final String OP_NAME = "TPUAnnotateTensorsWithDynamicShape";

  private List<Output<?>> tpuTensors;

  @SuppressWarnings("unchecked")
  public TPUAnnotateTensorsWithDynamicShape(Operation operation) {
    super(operation, OP_NAME);
    int outputIdx = 0;
    int tpuTensorsLength = operation.outputListLength("tpu_tensors");
    tpuTensors = Arrays.asList(operation.outputList(outputIdx, tpuTensorsLength));
    outputIdx += tpuTensorsLength;
  }

  /**
   * Factory method to create a class wrapping a new TPUAnnotateTensorsWithDynamicShape operation.
   *
   * @param scope current scope
   * @param tensors The tensors value
   * @return a new instance of TPUAnnotateTensorsWithDynamicShape
   */
  @Endpoint(
      describeByClass = true
  )
  public static TPUAnnotateTensorsWithDynamicShape create(Scope scope,
      Iterable<Operand<?>> tensors) {
    OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "TPUAnnotateTensorsWithDynamicShape");
    opBuilder.addInputList(Operands.asOutputs(tensors));
    return new TPUAnnotateTensorsWithDynamicShape(opBuilder.build());
  }

  /**
   * Gets tpuTensors.
   *
   * @return tpuTensors.
   */
  public List<Output<?>> tpuTensors() {
    return tpuTensors;
  }

  @Override
  @SuppressWarnings({"rawtypes", "unchecked"})
  public Iterator<Operand<TType>> iterator() {
    return (Iterator) tpuTensors.iterator();
  }

  @OpInputsMetadata(
      outputsClass = TPUAnnotateTensorsWithDynamicShape.class
  )
  public static class Inputs extends RawOpInputs<TPUAnnotateTensorsWithDynamicShape> {
    /**
     * The tensors input
     */
    public final Iterable<Operand<?>> tensors;

    /**
     * The T attribute
     */
    public final DataType[] T;

    public Inputs(GraphOperation op) {
      super(new TPUAnnotateTensorsWithDynamicShape(op), op, Arrays.asList("T"));
      int inputIndex = 0;
      int tensorsLength = op.inputListLength("tensors");
      tensors = Arrays.asList((Operand<?>[]) op.inputList(inputIndex, tensorsLength));
      inputIndex += tensorsLength;
      T = op.attributes().getAttrTypeList("T");
    }
  }
}
package org.tensorflow.op.tpu;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import org.tensorflow.GraphOperation;
import org.tensorflow.Operand;
import org.tensorflow.Operation;
import org.tensorflow.OperationBuilder;
import org.tensorflow.Output;
import org.tensorflow.op.Operands;
import org.tensorflow.op.RawOp;
import org.tensorflow.op.RawOpInputs;
import org.tensorflow.op.Scope;
import org.tensorflow.op.annotation.Endpoint;
import org.tensorflow.op.annotation.OpInputsMetadata;
import org.tensorflow.op.annotation.OpMetadata;
import org.tensorflow.op.annotation.Operator;
import org.tensorflow.proto.DataType;
import org.tensorflow.types.TInt32;
import org.tensorflow.types.family.TType;

/**
 * Op that copies host tensor to device with dynamic shape support.
 * For internal use only.
 */
@OpMetadata(
    opType = TPUCopyWithDynamicShape.OP_NAME,
    inputsClass = TPUCopyWithDynamicShape.Inputs.class
)
@Operator(
    group = "tpu"
)
public final class TPUCopyWithDynamicShape extends RawOp implements Iterable<Operand<TType>> {
  /**
   * The name of this op, as known by TensorFlow core engine
   */
  public static final String OP_NAME = "TPUCopyWithDynamicShape";

  private List<Output<?>> tpuTensors;

  @SuppressWarnings("unchecked")
  public TPUCopyWithDynamicShape(Operation operation) {
    super(operation, OP_NAME);
    int outputIdx = 0;
    int tpuTensorsLength = operation.outputListLength("tpu_tensors");
    tpuTensors = Arrays.asList(operation.outputList(outputIdx, tpuTensorsLength));
    outputIdx += tpuTensorsLength;
  }

  /**
   * Factory method to create a class wrapping a new TPUCopyWithDynamicShape operation.
   *
   * @param scope current scope
   * @param tensors The tensors value
   * @param unpaddedSizes The unpaddedSizes value
   * @return a new instance of TPUCopyWithDynamicShape
   */
  @Endpoint(
      describeByClass = true
  )
  public static TPUCopyWithDynamicShape create(Scope scope, Iterable<Operand<?>> tensors,
      Iterable<Operand<TInt32>> unpaddedSizes) {
    OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "TPUCopyWithDynamicShape");
    opBuilder.addInputList(Operands.asOutputs(tensors));
    opBuilder.addInputList(Operands.asOutputs(unpaddedSizes));
    return new TPUCopyWithDynamicShape(opBuilder.build());
  }

  /**
   * Gets tpuTensors.
   *
   * @return tpuTensors.
   */
  public List<Output<?>> tpuTensors() {
    return tpuTensors;
  }

  @Override
  @SuppressWarnings({"rawtypes", "unchecked"})
  public Iterator<Operand<TType>> iterator() {
    return (Iterator) tpuTensors.iterator();
  }

  @OpInputsMetadata(
      outputsClass = TPUCopyWithDynamicShape.class
  )
  public static class Inputs extends RawOpInputs<TPUCopyWithDynamicShape> {
    /**
     * The tensors input
     */
    public final Iterable<Operand<?>> tensors;

    /**
     * The unpaddedSizes input
     */
    public final Iterable<Operand<TInt32>> unpaddedSizes;

    /**
     * The T attribute
     */
    public final DataType[] T;

    public Inputs(GraphOperation op) {
      super(new TPUCopyWithDynamicShape(op), op, Arrays.asList("T"));
      int inputIndex = 0;
      int tensorsLength = op.inputListLength("tensors");
      tensors = Arrays.asList((Operand<?>[]) op.inputList(inputIndex, tensorsLength));
      inputIndex += tensorsLength;
      int unpaddedSizesLength = op.inputListLength("unpadded_sizes");
      unpaddedSizes = Arrays.asList((Operand<TInt32>[]) op.inputList(inputIndex, unpaddedSizesLength));
      inputIndex += unpaddedSizesLength;
      T = op.attributes().getAttrTypeList("T");
    }
  }
}
+ * + * @param gradY the gradY option + * @return this Options instance. + */ + public static Options gradY(Boolean gradY) { + return new Options().gradY(gradY); + } + /** * Gets output. * 3-D or higher with shape {@code [..., r_o, c_o]} @@ -155,6 +181,10 @@ public static class Options { private Boolean adjY; + private Boolean gradX; + + private Boolean gradY; + private Options() { } @@ -179,6 +209,28 @@ public Options adjY(Boolean adjY) { this.adjY = adjY; return this; } + + /** + * Sets the gradX option. + * + * @param gradX the gradX option + * @return this Options instance. + */ + public Options gradX(Boolean gradX) { + this.gradX = gradX; + return this; + } + + /** + * Sets the gradY option. + * + * @param gradY the gradY option + * @return this Options instance. + */ + public Options gradY(Boolean gradY) { + this.gradY = gradY; + return this; + } } @OpInputsMetadata( @@ -220,8 +272,18 @@ public static class Inputs extends RawOpInputs> { */ public final boolean adjY; + /** + * The gradX attribute + */ + public final boolean gradX; + + /** + * The gradY attribute + */ + public final boolean gradY; + public Inputs(GraphOperation op) { - super(new BatchMatMul<>(op), op, Arrays.asList("Ta", "Tb", "Tout", "adj_x", "adj_y")); + super(new BatchMatMul<>(op), op, Arrays.asList("Ta", "Tb", "Tout", "adj_x", "adj_y", "grad_x", "grad_y")); int inputIndex = 0; x = (Operand) op.input(inputIndex++); y = (Operand) op.input(inputIndex++); @@ -230,6 +292,8 @@ public Inputs(GraphOperation op) { Tout = op.attributes().getAttrType("Tout"); adjX = op.attributes().getAttrBool("adj_x"); adjY = op.attributes().getAttrBool("adj_y"); + gradX = op.attributes().getAttrBool("grad_x"); + gradY = op.attributes().getAttrBool("grad_y"); } } } diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdagrad.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdagrad.java new file mode 100644 index 
00000000000..f4e7db1771e --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdagrad.java @@ -0,0 +1,154 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseCoreAdagrad operation + */ +@OpMetadata( + opType = XlaSparseCoreAdagrad.OP_NAME, + inputsClass = XlaSparseCoreAdagrad.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseCoreAdagrad extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseCoreAdagrad"; + + private Output updatedEmbeddingTable; + + private Output updatedAccumulator; + + public 
XlaSparseCoreAdagrad(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + updatedAccumulator = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseCoreAdagrad operation. + * + * @param scope current scope + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param accumulator The accumulator value + * @param embeddingTable The embeddingTable value + * @param featureWidth The value of the featureWidth attribute + * @return a new instance of XlaSparseCoreAdagrad + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseCoreAdagrad create(Scope scope, Operand indices, + Operand gradient, Operand learningRate, Operand accumulator, + Operand embeddingTable, Long featureWidth) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseCoreAdagrad"); + opBuilder.addInput(indices.asOutput()); + opBuilder.addInput(gradient.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(accumulator.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.setAttr("feature_width", featureWidth); + return new XlaSparseCoreAdagrad(opBuilder.build()); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedAccumulator. + * + * @return updatedAccumulator. 
+ */ + public Output updatedAccumulator() { + return updatedAccumulator; + } + + @OpInputsMetadata( + outputsClass = XlaSparseCoreAdagrad.class + ) + public static class Inputs extends RawOpInputs { + /** + * The indices input + */ + public final Operand indices; + + /** + * The gradient input + */ + public final Operand gradient; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The accumulator input + */ + public final Operand accumulator; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The featureWidth attribute + */ + public final long featureWidth; + + public Inputs(GraphOperation op) { + super(new XlaSparseCoreAdagrad(op), op, Arrays.asList("feature_width")); + int inputIndex = 0; + indices = (Operand) op.input(inputIndex++); + gradient = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + accumulator = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + featureWidth = op.attributes().getAttrInt("feature_width"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdagradMomentum.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdagradMomentum.java new file mode 100644 index 00000000000..d65be317989 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdagradMomentum.java @@ -0,0 +1,216 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseCoreAdagradMomentum operation + */ +@OpMetadata( + opType = XlaSparseCoreAdagradMomentum.OP_NAME, + inputsClass = XlaSparseCoreAdagradMomentum.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseCoreAdagradMomentum extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseCoreAdagradMomentum"; + + private Output updatedEmbeddingTable; + + private Output updatedAccumulator; + + private Output updatedMomentum; + + public XlaSparseCoreAdagradMomentum(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + updatedAccumulator = operation.output(outputIdx++); + updatedMomentum = operation.output(outputIdx++); + } + + /** + * Factory method to create a 
class wrapping a new XlaSparseCoreAdagradMomentum operation. + * + * @param scope current scope + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param beta1 The beta1 value + * @param epsilon The epsilon value + * @param accumulator The accumulator value + * @param momentum The momentum value + * @param embeddingTable The embeddingTable value + * @param featureWidth The value of the featureWidth attribute + * @param useNesterov The value of the useNesterov attribute + * @param beta2 The value of the beta2 attribute + * @param exponent The value of the exponent attribute + * @return a new instance of XlaSparseCoreAdagradMomentum + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseCoreAdagradMomentum create(Scope scope, Operand indices, + Operand gradient, Operand learningRate, Operand beta1, + Operand epsilon, Operand accumulator, Operand momentum, + Operand embeddingTable, Long featureWidth, Boolean useNesterov, Float beta2, + Float exponent) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseCoreAdagradMomentum"); + opBuilder.addInput(indices.asOutput()); + opBuilder.addInput(gradient.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(beta1.asOutput()); + opBuilder.addInput(epsilon.asOutput()); + opBuilder.addInput(accumulator.asOutput()); + opBuilder.addInput(momentum.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.setAttr("feature_width", featureWidth); + opBuilder.setAttr("use_nesterov", useNesterov); + opBuilder.setAttr("beta_2", beta2); + opBuilder.setAttr("exponent", exponent); + return new XlaSparseCoreAdagradMomentum(opBuilder.build()); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedAccumulator. + * + * @return updatedAccumulator. 
+ */ + public Output updatedAccumulator() { + return updatedAccumulator; + } + + /** + * Gets updatedMomentum. + * + * @return updatedMomentum. + */ + public Output updatedMomentum() { + return updatedMomentum; + } + + @OpInputsMetadata( + outputsClass = XlaSparseCoreAdagradMomentum.class + ) + public static class Inputs extends RawOpInputs { + /** + * The indices input + */ + public final Operand indices; + + /** + * The gradient input + */ + public final Operand gradient; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The beta1 input + */ + public final Operand beta1; + + /** + * The epsilon input + */ + public final Operand epsilon; + + /** + * The accumulator input + */ + public final Operand accumulator; + + /** + * The momentum input + */ + public final Operand momentum; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The featureWidth attribute + */ + public final long featureWidth; + + /** + * The useNesterov attribute + */ + public final boolean useNesterov; + + /** + * The beta2 attribute + */ + public final float beta2; + + /** + * The exponent attribute + */ + public final float exponent; + + public Inputs(GraphOperation op) { + super(new XlaSparseCoreAdagradMomentum(op), op, Arrays.asList("feature_width", "use_nesterov", "beta_2", "exponent")); + int inputIndex = 0; + indices = (Operand) op.input(inputIndex++); + gradient = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + beta1 = (Operand) op.input(inputIndex++); + epsilon = (Operand) op.input(inputIndex++); + accumulator = (Operand) op.input(inputIndex++); + momentum = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + featureWidth = op.attributes().getAttrInt("feature_width"); + useNesterov = op.attributes().getAttrBool("use_nesterov"); + beta2 = op.attributes().getAttrFloat("beta_2"); + exponent = op.attributes().getAttrFloat("exponent"); + 
} + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdam.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdam.java new file mode 100644 index 00000000000..8ec32200fad --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreAdam.java @@ -0,0 +1,208 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! 
+ +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseCoreAdam operation + */ +@OpMetadata( + opType = XlaSparseCoreAdam.OP_NAME, + inputsClass = XlaSparseCoreAdam.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseCoreAdam extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseCoreAdam"; + + private Output updatedEmbeddingTable; + + private Output updatedVelocity; + + private Output updatedMomentum; + + public XlaSparseCoreAdam(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + updatedVelocity = operation.output(outputIdx++); + updatedMomentum = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseCoreAdam operation. 
+ * + * @param scope current scope + * @param embeddingTable The embeddingTable value + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param momentum The momentum value + * @param velocity The velocity value + * @param beta1 The beta1 value + * @param beta2 The beta2 value + * @param epsilon The epsilon value + * @param featureWidth The value of the featureWidth attribute + * @param useSumInsideSqrt The value of the useSumInsideSqrt attribute + * @return a new instance of XlaSparseCoreAdam + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseCoreAdam create(Scope scope, Operand embeddingTable, + Operand indices, Operand gradient, Operand learningRate, + Operand momentum, Operand velocity, Operand beta1, + Operand beta2, Operand epsilon, Long featureWidth, + Boolean useSumInsideSqrt) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseCoreAdam"); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(indices.asOutput()); + opBuilder.addInput(gradient.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(momentum.asOutput()); + opBuilder.addInput(velocity.asOutput()); + opBuilder.addInput(beta1.asOutput()); + opBuilder.addInput(beta2.asOutput()); + opBuilder.addInput(epsilon.asOutput()); + opBuilder.setAttr("feature_width", featureWidth); + opBuilder.setAttr("use_sum_inside_sqrt", useSumInsideSqrt); + return new XlaSparseCoreAdam(opBuilder.build()); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedVelocity. + * + * @return updatedVelocity. + */ + public Output updatedVelocity() { + return updatedVelocity; + } + + /** + * Gets updatedMomentum. + * + * @return updatedMomentum. 
+ */ + public Output updatedMomentum() { + return updatedMomentum; + } + + @OpInputsMetadata( + outputsClass = XlaSparseCoreAdam.class + ) + public static class Inputs extends RawOpInputs { + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The indices input + */ + public final Operand indices; + + /** + * The gradient input + */ + public final Operand gradient; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The momentum input + */ + public final Operand momentum; + + /** + * The velocity input + */ + public final Operand velocity; + + /** + * The beta1 input + */ + public final Operand beta1; + + /** + * The beta2 input + */ + public final Operand beta2; + + /** + * The epsilon input + */ + public final Operand epsilon; + + /** + * The featureWidth attribute + */ + public final long featureWidth; + + /** + * The useSumInsideSqrt attribute + */ + public final boolean useSumInsideSqrt; + + public Inputs(GraphOperation op) { + super(new XlaSparseCoreAdam(op), op, Arrays.asList("feature_width", "use_sum_inside_sqrt")); + int inputIndex = 0; + embeddingTable = (Operand) op.input(inputIndex++); + indices = (Operand) op.input(inputIndex++); + gradient = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + momentum = (Operand) op.input(inputIndex++); + velocity = (Operand) op.input(inputIndex++); + beta1 = (Operand) op.input(inputIndex++); + beta2 = (Operand) op.input(inputIndex++); + epsilon = (Operand) op.input(inputIndex++); + featureWidth = op.attributes().getAttrInt("feature_width"); + useSumInsideSqrt = op.attributes().getAttrBool("use_sum_inside_sqrt"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreFtrl.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreFtrl.java new file mode 100644 index 00000000000..5ab961123b4 --- /dev/null +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreFtrl.java @@ -0,0 +1,216 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseCoreFtrl operation + */ +@OpMetadata( + opType = XlaSparseCoreFtrl.OP_NAME, + inputsClass = XlaSparseCoreFtrl.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseCoreFtrl extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseCoreFtrl"; + + private Output updatedEmbeddingTable; + + private Output updatedAccumulator; + + private Output updatedLinear; + + public XlaSparseCoreFtrl(Operation operation) { + 
super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + updatedAccumulator = operation.output(outputIdx++); + updatedLinear = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseCoreFtrl operation. + * + * @param scope current scope + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param linear The linear value + * @param learningRate The learningRate value + * @param indices The indices value + * @param gradient The gradient value + * @param beta The beta value + * @param learningRatePower The learningRatePower value + * @param l2RegularizationStrength The l2RegularizationStrength value + * @param featureWidth The value of the featureWidth attribute + * @param multiplyLinearByLearningRate The value of the multiplyLinearByLearningRate attribute + * @param l1RegularizationStrength The value of the l1RegularizationStrength attribute + * @return a new instance of XlaSparseCoreFtrl + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseCoreFtrl create(Scope scope, Operand embeddingTable, + Operand accumulator, Operand linear, Operand learningRate, + Operand indices, Operand gradient, Operand beta, + Operand learningRatePower, Operand l2RegularizationStrength, + Long featureWidth, Boolean multiplyLinearByLearningRate, Float l1RegularizationStrength) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseCoreFtrl"); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(accumulator.asOutput()); + opBuilder.addInput(linear.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(indices.asOutput()); + opBuilder.addInput(gradient.asOutput()); + opBuilder.addInput(beta.asOutput()); + opBuilder.addInput(learningRatePower.asOutput()); + opBuilder.addInput(l2RegularizationStrength.asOutput()); + opBuilder.setAttr("feature_width", featureWidth); + 
opBuilder.setAttr("multiply_linear_by_learning_rate", multiplyLinearByLearningRate); + opBuilder.setAttr("l1_regularization_strength", l1RegularizationStrength); + return new XlaSparseCoreFtrl(opBuilder.build()); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedAccumulator. + * + * @return updatedAccumulator. + */ + public Output updatedAccumulator() { + return updatedAccumulator; + } + + /** + * Gets updatedLinear. + * + * @return updatedLinear. + */ + public Output updatedLinear() { + return updatedLinear; + } + + @OpInputsMetadata( + outputsClass = XlaSparseCoreFtrl.class + ) + public static class Inputs extends RawOpInputs { + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The accumulator input + */ + public final Operand accumulator; + + /** + * The linear input + */ + public final Operand linear; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The indices input + */ + public final Operand indices; + + /** + * The gradient input + */ + public final Operand gradient; + + /** + * The beta input + */ + public final Operand beta; + + /** + * The learningRatePower input + */ + public final Operand learningRatePower; + + /** + * The l2RegularizationStrength input + */ + public final Operand l2RegularizationStrength; + + /** + * The featureWidth attribute + */ + public final long featureWidth; + + /** + * The multiplyLinearByLearningRate attribute + */ + public final boolean multiplyLinearByLearningRate; + + /** + * The l1RegularizationStrength attribute + */ + public final float l1RegularizationStrength; + + public Inputs(GraphOperation op) { + super(new XlaSparseCoreFtrl(op), op, Arrays.asList("feature_width", "multiply_linear_by_learning_rate", "l1_regularization_strength")); + int inputIndex = 0; + embeddingTable = (Operand) 
op.input(inputIndex++); + accumulator = (Operand) op.input(inputIndex++); + linear = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + indices = (Operand) op.input(inputIndex++); + gradient = (Operand) op.input(inputIndex++); + beta = (Operand) op.input(inputIndex++); + learningRatePower = (Operand) op.input(inputIndex++); + l2RegularizationStrength = (Operand) op.input(inputIndex++); + featureWidth = op.attributes().getAttrInt("feature_width"); + multiplyLinearByLearningRate = op.attributes().getAttrBool("multiply_linear_by_learning_rate"); + l1RegularizationStrength = op.attributes().getAttrFloat("l1_regularization_strength"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreSgd.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreSgd.java new file mode 100644 index 00000000000..70830868226 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseCoreSgd.java @@ -0,0 +1,139 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! 
+ +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseCoreSgd operation + */ +@OpMetadata( + opType = XlaSparseCoreSgd.OP_NAME, + inputsClass = XlaSparseCoreSgd.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseCoreSgd extends RawOp implements Operand { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseCoreSgd"; + + private Output updatedEmbeddingTable; + + public XlaSparseCoreSgd(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseCoreSgd operation. 
+ * + * @param scope current scope + * @param indices The indices value + * @param gradient The gradient value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param featureWidth The value of the featureWidth attribute + * @return a new instance of XlaSparseCoreSgd + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseCoreSgd create(Scope scope, Operand indices, + Operand gradient, Operand learningRate, Operand embeddingTable, + Long featureWidth) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseCoreSgd"); + opBuilder.addInput(indices.asOutput()); + opBuilder.addInput(gradient.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.setAttr("feature_width", featureWidth); + return new XlaSparseCoreSgd(opBuilder.build()); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + @Override + public Output asOutput() { + return updatedEmbeddingTable; + } + + @OpInputsMetadata( + outputsClass = XlaSparseCoreSgd.class + ) + public static class Inputs extends RawOpInputs { + /** + * The indices input + */ + public final Operand indices; + + /** + * The gradient input + */ + public final Operand gradient; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The featureWidth attribute + */ + public final long featureWidth; + + public Inputs(GraphOperation op) { + super(new XlaSparseCoreSgd(op), op, Arrays.asList("feature_width")); + int inputIndex = 0; + indices = (Operand) op.input(inputIndex++); + gradient = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + featureWidth = 
op.attributes().getAttrInt("feature_width"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmul.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmul.java new file mode 100644 index 00000000000..0d3fab06b13 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmul.java @@ -0,0 +1,208 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! 
+ +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; +import org.tensorflow.types.family.TType; + +/** + * The XlaSparseDenseMatmul operation + */ +@OpMetadata( + opType = XlaSparseDenseMatmul.OP_NAME, + inputsClass = XlaSparseDenseMatmul.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseDenseMatmul extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseDenseMatmul"; + + private Output activations; + + private Output rowPointers; + + private Output sortedEmbeddingIds; + + private Output sortedSampleIds; + + private Output sortedGains; + + public XlaSparseDenseMatmul(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + activations = operation.output(outputIdx++); + rowPointers = operation.output(outputIdx++); + sortedEmbeddingIds = operation.output(outputIdx++); + sortedSampleIds = operation.output(outputIdx++); + sortedGains = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseDenseMatmul operation. 
+ * + * @param scope current scope + * @param rowIds The rowIds value + * @param colIds The colIds value + * @param values The values value + * @param offsets The offsets value + * @param embeddingTable The embeddingTable value + * @param maxIdsPerPartition The value of the maxIdsPerPartition attribute + * @param maxUniqueIdsPerPartition The value of the maxUniqueIdsPerPartition attribute + * @param inputSize The value of the inputSize attribute + * @return a new instance of XlaSparseDenseMatmul + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseDenseMatmul create(Scope scope, Operand rowIds, + Operand colIds, Operand values, Operand offsets, + Operand embeddingTable, Long maxIdsPerPartition, Long maxUniqueIdsPerPartition, + Long inputSize) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseDenseMatmul"); + opBuilder.addInput(rowIds.asOutput()); + opBuilder.addInput(colIds.asOutput()); + opBuilder.addInput(values.asOutput()); + opBuilder.addInput(offsets.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.setAttr("max_ids_per_partition", maxIdsPerPartition); + opBuilder.setAttr("max_unique_ids_per_partition", maxUniqueIdsPerPartition); + opBuilder.setAttr("input_size", inputSize); + return new XlaSparseDenseMatmul(opBuilder.build()); + } + + /** + * Gets activations. + * + * @return activations. + */ + public Output activations() { + return activations; + } + + /** + * Gets rowPointers. + * + * @return rowPointers. + */ + public Output rowPointers() { + return rowPointers; + } + + /** + * Gets sortedEmbeddingIds. + * + * @return sortedEmbeddingIds. + */ + public Output sortedEmbeddingIds() { + return sortedEmbeddingIds; + } + + /** + * Gets sortedSampleIds. + * + * @return sortedSampleIds. + */ + public Output sortedSampleIds() { + return sortedSampleIds; + } + + /** + * Gets sortedGains. + * + * @return sortedGains. 
+ */ + public Output sortedGains() { + return sortedGains; + } + + @OpInputsMetadata( + outputsClass = XlaSparseDenseMatmul.class + ) + public static class Inputs extends RawOpInputs { + /** + * The rowIds input + */ + public final Operand rowIds; + + /** + * The colIds input + */ + public final Operand colIds; + + /** + * The values input + */ + public final Operand values; + + /** + * The offsets input + */ + public final Operand offsets; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The maxIdsPerPartition attribute + */ + public final long maxIdsPerPartition; + + /** + * The maxUniqueIdsPerPartition attribute + */ + public final long maxUniqueIdsPerPartition; + + /** + * The inputSize attribute + */ + public final long inputSize; + + public Inputs(GraphOperation op) { + super(new XlaSparseDenseMatmul(op), op, Arrays.asList("max_ids_per_partition", "max_unique_ids_per_partition", "input_size")); + int inputIndex = 0; + rowIds = (Operand) op.input(inputIndex++); + colIds = (Operand) op.input(inputIndex++); + values = (Operand) op.input(inputIndex++); + offsets = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + maxIdsPerPartition = op.attributes().getAttrInt("max_ids_per_partition"); + maxUniqueIdsPerPartition = op.attributes().getAttrInt("max_unique_ids_per_partition"); + inputSize = op.attributes().getAttrInt("input_size"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.java new file mode 100644 index 00000000000..b63cba97719 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdagradAndCsrInput.java @@ -0,0 +1,266 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseDenseMatmulGradWithAdagradAndCsrInput operation + */ +@OpMetadata( + opType = XlaSparseDenseMatmulGradWithAdagradAndCsrInput.OP_NAME, + inputsClass = XlaSparseDenseMatmulGradWithAdagradAndCsrInput.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseDenseMatmulGradWithAdagradAndCsrInput extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"; + + private Output updatedEmbeddingTable; + + private Output updatedAccumulator; + + public XlaSparseDenseMatmulGradWithAdagradAndCsrInput(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; 
+ updatedEmbeddingTable = operation.output(outputIdx++); + updatedAccumulator = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseDenseMatmulGradWithAdagradAndCsrInput operation. + * + * @param scope current scope + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithAdagradAndCsrInput + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseDenseMatmulGradWithAdagradAndCsrInput create(Scope scope, + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand accumulator, Operand numMinibatchesPerPhysicalSparseCore, + String tableName, Options... 
options) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseDenseMatmulGradWithAdagradAndCsrInput"); + opBuilder.addInput(rowPointers.asOutput()); + opBuilder.addInput(sortedSampleIds.asOutput()); + opBuilder.addInput(sortedTokenIds.asOutput()); + opBuilder.addInput(sortedGains.asOutput()); + opBuilder.addInput(activationGradients.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(accumulator.asOutput()); + opBuilder.addInput(numMinibatchesPerPhysicalSparseCore.asOutput()); + opBuilder.setAttr("table_name", tableName); + if (options != null) { + for (Options opts : options) { + if (opts.clipWeightMin != null) { + opBuilder.setAttr("clip_weight_min", opts.clipWeightMin); + } + if (opts.clipWeightMax != null) { + opBuilder.setAttr("clip_weight_max", opts.clipWeightMax); + } + } + } + return new XlaSparseDenseMatmulGradWithAdagradAndCsrInput(opBuilder.build()); + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public static Options clipWeightMin(Float clipWeightMin) { + return new Options().clipWeightMin(clipWeightMin); + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. + */ + public static Options clipWeightMax(Float clipWeightMax) { + return new Options().clipWeightMax(clipWeightMax); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedAccumulator. + * + * @return updatedAccumulator. 
+ */ + public Output updatedAccumulator() { + return updatedAccumulator; + } + + /** + * Optional attributes for {@link org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithAdagradAndCsrInput} + */ + public static class Options { + private Float clipWeightMin; + + private Float clipWeightMax; + + private Options() { + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public Options clipWeightMin(Float clipWeightMin) { + this.clipWeightMin = clipWeightMin; + return this; + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. + */ + public Options clipWeightMax(Float clipWeightMax) { + this.clipWeightMax = clipWeightMax; + return this; + } + } + + @OpInputsMetadata( + outputsClass = XlaSparseDenseMatmulGradWithAdagradAndCsrInput.class + ) + public static class Inputs extends RawOpInputs { + /** + * The rowPointers input + */ + public final Operand rowPointers; + + /** + * The sortedSampleIds input + */ + public final Operand sortedSampleIds; + + /** + * The sortedTokenIds input + */ + public final Operand sortedTokenIds; + + /** + * The sortedGains input + */ + public final Operand sortedGains; + + /** + * The activationGradients input + */ + public final Operand activationGradients; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The accumulator input + */ + public final Operand accumulator; + + /** + * The numMinibatchesPerPhysicalSparseCore input + */ + public final Operand numMinibatchesPerPhysicalSparseCore; + + /** + * The clipWeightMin attribute + */ + public final float clipWeightMin; + + /** + * The clipWeightMax attribute + */ + public final float clipWeightMax; + + /** + * The tableName attribute + */ + public final String tableName; + + public Inputs(GraphOperation op) 
{ + super(new XlaSparseDenseMatmulGradWithAdagradAndCsrInput(op), op, Arrays.asList("clip_weight_min", "clip_weight_max", "table_name")); + int inputIndex = 0; + rowPointers = (Operand) op.input(inputIndex++); + sortedSampleIds = (Operand) op.input(inputIndex++); + sortedTokenIds = (Operand) op.input(inputIndex++); + sortedGains = (Operand) op.input(inputIndex++); + activationGradients = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + accumulator = (Operand) op.input(inputIndex++); + numMinibatchesPerPhysicalSparseCore = (Operand) op.input(inputIndex++); + clipWeightMin = op.attributes().getAttrFloat("clip_weight_min"); + clipWeightMax = op.attributes().getAttrFloat("clip_weight_max"); + tableName = op.attributes().getAttrString("table_name"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.java new file mode 100644 index 00000000000..faa117af196 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.java @@ -0,0 +1,327 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput operation + */ +@OpMetadata( + opType = XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.OP_NAME, + inputsClass = XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"; + + private Output updatedEmbeddingTable; + + private Output updatedAccumulator; + + private Output updatedMomenta; + + public XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + updatedAccumulator = operation.output(outputIdx++); + updatedMomenta = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput operation. 
+ * + * @param scope current scope + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param momenta The momenta value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param useNesterov The value of the useNesterov attribute + * @param exponent The value of the exponent attribute + * @param beta1 The value of the beta1 attribute + * @param beta2 The value of the beta2 attribute + * @param epsilon The value of the epsilon attribute + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput create(Scope scope, + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand accumulator, Operand momenta, + Operand numMinibatchesPerPhysicalSparseCore, Boolean useNesterov, Float exponent, + Float beta1, Float beta2, Float epsilon, String tableName, Options... 
options) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput"); + opBuilder.addInput(rowPointers.asOutput()); + opBuilder.addInput(sortedSampleIds.asOutput()); + opBuilder.addInput(sortedTokenIds.asOutput()); + opBuilder.addInput(sortedGains.asOutput()); + opBuilder.addInput(activationGradients.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(accumulator.asOutput()); + opBuilder.addInput(momenta.asOutput()); + opBuilder.addInput(numMinibatchesPerPhysicalSparseCore.asOutput()); + opBuilder.setAttr("use_nesterov", useNesterov); + opBuilder.setAttr("exponent", exponent); + opBuilder.setAttr("beta1", beta1); + opBuilder.setAttr("beta2", beta2); + opBuilder.setAttr("epsilon", epsilon); + opBuilder.setAttr("table_name", tableName); + if (options != null) { + for (Options opts : options) { + if (opts.clipWeightMin != null) { + opBuilder.setAttr("clip_weight_min", opts.clipWeightMin); + } + if (opts.clipWeightMax != null) { + opBuilder.setAttr("clip_weight_max", opts.clipWeightMax); + } + } + } + return new XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput(opBuilder.build()); + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public static Options clipWeightMin(Float clipWeightMin) { + return new Options().clipWeightMin(clipWeightMin); + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. + */ + public static Options clipWeightMax(Float clipWeightMax) { + return new Options().clipWeightMax(clipWeightMax); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedAccumulator. + * + * @return updatedAccumulator. 
+ */ + public Output updatedAccumulator() { + return updatedAccumulator; + } + + /** + * Gets updatedMomenta. + * + * @return updatedMomenta. + */ + public Output updatedMomenta() { + return updatedMomenta; + } + + /** + * Optional attributes for {@link org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput} + */ + public static class Options { + private Float clipWeightMin; + + private Float clipWeightMax; + + private Options() { + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public Options clipWeightMin(Float clipWeightMin) { + this.clipWeightMin = clipWeightMin; + return this; + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. + */ + public Options clipWeightMax(Float clipWeightMax) { + this.clipWeightMax = clipWeightMax; + return this; + } + } + + @OpInputsMetadata( + outputsClass = XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput.class + ) + public static class Inputs extends RawOpInputs { + /** + * The rowPointers input + */ + public final Operand rowPointers; + + /** + * The sortedSampleIds input + */ + public final Operand sortedSampleIds; + + /** + * The sortedTokenIds input + */ + public final Operand sortedTokenIds; + + /** + * The sortedGains input + */ + public final Operand sortedGains; + + /** + * The activationGradients input + */ + public final Operand activationGradients; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The accumulator input + */ + public final Operand accumulator; + + /** + * The momenta input + */ + public final Operand momenta; + + /** + * The numMinibatchesPerPhysicalSparseCore input + */ + public final Operand numMinibatchesPerPhysicalSparseCore; + + /** + * The useNesterov attribute + */ + public 
final boolean useNesterov; + + /** + * The exponent attribute + */ + public final float exponent; + + /** + * The beta1 attribute + */ + public final float beta1; + + /** + * The beta2 attribute + */ + public final float beta2; + + /** + * The epsilon attribute + */ + public final float epsilon; + + /** + * The clipWeightMin attribute + */ + public final float clipWeightMin; + + /** + * The clipWeightMax attribute + */ + public final float clipWeightMax; + + /** + * The tableName attribute + */ + public final String tableName; + + public Inputs(GraphOperation op) { + super(new XlaSparseDenseMatmulGradWithAdagradMomentumAndCsrInput(op), op, Arrays.asList("use_nesterov", "exponent", "beta1", "beta2", "epsilon", "clip_weight_min", "clip_weight_max", "table_name")); + int inputIndex = 0; + rowPointers = (Operand) op.input(inputIndex++); + sortedSampleIds = (Operand) op.input(inputIndex++); + sortedTokenIds = (Operand) op.input(inputIndex++); + sortedGains = (Operand) op.input(inputIndex++); + activationGradients = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + accumulator = (Operand) op.input(inputIndex++); + momenta = (Operand) op.input(inputIndex++); + numMinibatchesPerPhysicalSparseCore = (Operand) op.input(inputIndex++); + useNesterov = op.attributes().getAttrBool("use_nesterov"); + exponent = op.attributes().getAttrFloat("exponent"); + beta1 = op.attributes().getAttrFloat("beta1"); + beta2 = op.attributes().getAttrFloat("beta2"); + epsilon = op.attributes().getAttrFloat("epsilon"); + clipWeightMin = op.attributes().getAttrFloat("clip_weight_min"); + clipWeightMax = op.attributes().getAttrFloat("clip_weight_max"); + tableName = op.attributes().getAttrString("table_name"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdamAndCsrInput.java 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdamAndCsrInput.java new file mode 100644 index 00000000000..a39190054b3 --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithAdamAndCsrInput.java @@ -0,0 +1,319 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! 
+ +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseDenseMatmulGradWithAdamAndCsrInput operation + */ +@OpMetadata( + opType = XlaSparseDenseMatmulGradWithAdamAndCsrInput.OP_NAME, + inputsClass = XlaSparseDenseMatmulGradWithAdamAndCsrInput.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseDenseMatmulGradWithAdamAndCsrInput extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseDenseMatmulGradWithAdamAndCsrInput"; + + private Output updatedEmbeddingTable; + + private Output updatedMomenta; + + private Output updatedVelocity; + + public XlaSparseDenseMatmulGradWithAdamAndCsrInput(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + updatedMomenta = operation.output(outputIdx++); + updatedVelocity = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseDenseMatmulGradWithAdamAndCsrInput operation. 
+ * + * @param scope current scope + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param momenta The momenta value + * @param velocity The velocity value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param useSumInsideSqrt The value of the useSumInsideSqrt attribute + * @param beta1 The value of the beta1 attribute + * @param beta2 The value of the beta2 attribute + * @param epsilon The value of the epsilon attribute + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithAdamAndCsrInput + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseDenseMatmulGradWithAdamAndCsrInput create(Scope scope, + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, Operand momenta, + Operand velocity, Operand numMinibatchesPerPhysicalSparseCore, + Boolean useSumInsideSqrt, Float beta1, Float beta2, Float epsilon, String tableName, + Options... 
options) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseDenseMatmulGradWithAdamAndCsrInput"); + opBuilder.addInput(rowPointers.asOutput()); + opBuilder.addInput(sortedSampleIds.asOutput()); + opBuilder.addInput(sortedTokenIds.asOutput()); + opBuilder.addInput(sortedGains.asOutput()); + opBuilder.addInput(activationGradients.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(momenta.asOutput()); + opBuilder.addInput(velocity.asOutput()); + opBuilder.addInput(numMinibatchesPerPhysicalSparseCore.asOutput()); + opBuilder.setAttr("use_sum_inside_sqrt", useSumInsideSqrt); + opBuilder.setAttr("beta1", beta1); + opBuilder.setAttr("beta2", beta2); + opBuilder.setAttr("epsilon", epsilon); + opBuilder.setAttr("table_name", tableName); + if (options != null) { + for (Options opts : options) { + if (opts.clipWeightMin != null) { + opBuilder.setAttr("clip_weight_min", opts.clipWeightMin); + } + if (opts.clipWeightMax != null) { + opBuilder.setAttr("clip_weight_max", opts.clipWeightMax); + } + } + } + return new XlaSparseDenseMatmulGradWithAdamAndCsrInput(opBuilder.build()); + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public static Options clipWeightMin(Float clipWeightMin) { + return new Options().clipWeightMin(clipWeightMin); + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. + */ + public static Options clipWeightMax(Float clipWeightMax) { + return new Options().clipWeightMax(clipWeightMax); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedMomenta. + * + * @return updatedMomenta. 
+ */ + public Output updatedMomenta() { + return updatedMomenta; + } + + /** + * Gets updatedVelocity. + * + * @return updatedVelocity. + */ + public Output updatedVelocity() { + return updatedVelocity; + } + + /** + * Optional attributes for {@link org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithAdamAndCsrInput} + */ + public static class Options { + private Float clipWeightMin; + + private Float clipWeightMax; + + private Options() { + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public Options clipWeightMin(Float clipWeightMin) { + this.clipWeightMin = clipWeightMin; + return this; + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. + */ + public Options clipWeightMax(Float clipWeightMax) { + this.clipWeightMax = clipWeightMax; + return this; + } + } + + @OpInputsMetadata( + outputsClass = XlaSparseDenseMatmulGradWithAdamAndCsrInput.class + ) + public static class Inputs extends RawOpInputs { + /** + * The rowPointers input + */ + public final Operand rowPointers; + + /** + * The sortedSampleIds input + */ + public final Operand sortedSampleIds; + + /** + * The sortedTokenIds input + */ + public final Operand sortedTokenIds; + + /** + * The sortedGains input + */ + public final Operand sortedGains; + + /** + * The activationGradients input + */ + public final Operand activationGradients; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The momenta input + */ + public final Operand momenta; + + /** + * The velocity input + */ + public final Operand velocity; + + /** + * The numMinibatchesPerPhysicalSparseCore input + */ + public final Operand numMinibatchesPerPhysicalSparseCore; + + /** + * The useSumInsideSqrt attribute + */ + public final boolean useSumInsideSqrt; 
+ + /** + * The beta1 attribute + */ + public final float beta1; + + /** + * The beta2 attribute + */ + public final float beta2; + + /** + * The epsilon attribute + */ + public final float epsilon; + + /** + * The clipWeightMin attribute + */ + public final float clipWeightMin; + + /** + * The clipWeightMax attribute + */ + public final float clipWeightMax; + + /** + * The tableName attribute + */ + public final String tableName; + + public Inputs(GraphOperation op) { + super(new XlaSparseDenseMatmulGradWithAdamAndCsrInput(op), op, Arrays.asList("use_sum_inside_sqrt", "beta1", "beta2", "epsilon", "clip_weight_min", "clip_weight_max", "table_name")); + int inputIndex = 0; + rowPointers = (Operand) op.input(inputIndex++); + sortedSampleIds = (Operand) op.input(inputIndex++); + sortedTokenIds = (Operand) op.input(inputIndex++); + sortedGains = (Operand) op.input(inputIndex++); + activationGradients = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + momenta = (Operand) op.input(inputIndex++); + velocity = (Operand) op.input(inputIndex++); + numMinibatchesPerPhysicalSparseCore = (Operand) op.input(inputIndex++); + useSumInsideSqrt = op.attributes().getAttrBool("use_sum_inside_sqrt"); + beta1 = op.attributes().getAttrFloat("beta1"); + beta2 = op.attributes().getAttrFloat("beta2"); + epsilon = op.attributes().getAttrFloat("epsilon"); + clipWeightMin = op.attributes().getAttrFloat("clip_weight_min"); + clipWeightMax = op.attributes().getAttrFloat("clip_weight_max"); + tableName = op.attributes().getAttrString("table_name"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.java new file mode 100644 index 00000000000..b8d005b7059 --- /dev/null +++ 
b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithFtrlAndCsrInput.java @@ -0,0 +1,328 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseDenseMatmulGradWithFtrlAndCsrInput operation + */ +@OpMetadata( + opType = XlaSparseDenseMatmulGradWithFtrlAndCsrInput.OP_NAME, + inputsClass = XlaSparseDenseMatmulGradWithFtrlAndCsrInput.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseDenseMatmulGradWithFtrlAndCsrInput extends RawOp { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"; + + private 
Output updatedEmbeddingTable; + + private Output updatedAccumulator; + + private Output updatedLinear; + + public XlaSparseDenseMatmulGradWithFtrlAndCsrInput(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + updatedAccumulator = operation.output(outputIdx++); + updatedLinear = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseDenseMatmulGradWithFtrlAndCsrInput operation. + * + * @param scope current scope + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param accumulator The accumulator value + * @param linear The linear value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param multiplyLinearByLearningRate The value of the multiplyLinearByLearningRate attribute + * @param beta The value of the beta attribute + * @param learningRatePower The value of the learningRatePower attribute + * @param l1RegularizationStrength The value of the l1RegularizationStrength attribute + * @param l2RegularizationStrength The value of the l2RegularizationStrength attribute + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithFtrlAndCsrInput + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseDenseMatmulGradWithFtrlAndCsrInput create(Scope scope, + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand accumulator, Operand linear, + Operand 
numMinibatchesPerPhysicalSparseCore, Boolean multiplyLinearByLearningRate, + Float beta, Float learningRatePower, Float l1RegularizationStrength, + Float l2RegularizationStrength, String tableName, Options... options) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseDenseMatmulGradWithFtrlAndCsrInput"); + opBuilder.addInput(rowPointers.asOutput()); + opBuilder.addInput(sortedSampleIds.asOutput()); + opBuilder.addInput(sortedTokenIds.asOutput()); + opBuilder.addInput(sortedGains.asOutput()); + opBuilder.addInput(activationGradients.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(accumulator.asOutput()); + opBuilder.addInput(linear.asOutput()); + opBuilder.addInput(numMinibatchesPerPhysicalSparseCore.asOutput()); + opBuilder.setAttr("multiply_linear_by_learning_rate", multiplyLinearByLearningRate); + opBuilder.setAttr("beta", beta); + opBuilder.setAttr("learning_rate_power", learningRatePower); + opBuilder.setAttr("l1_regularization_strength", l1RegularizationStrength); + opBuilder.setAttr("l2_regularization_strength", l2RegularizationStrength); + opBuilder.setAttr("table_name", tableName); + if (options != null) { + for (Options opts : options) { + if (opts.clipWeightMin != null) { + opBuilder.setAttr("clip_weight_min", opts.clipWeightMin); + } + if (opts.clipWeightMax != null) { + opBuilder.setAttr("clip_weight_max", opts.clipWeightMax); + } + } + } + return new XlaSparseDenseMatmulGradWithFtrlAndCsrInput(opBuilder.build()); + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public static Options clipWeightMin(Float clipWeightMin) { + return new Options().clipWeightMin(clipWeightMin); + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. 
+ */ + public static Options clipWeightMax(Float clipWeightMax) { + return new Options().clipWeightMax(clipWeightMax); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + /** + * Gets updatedAccumulator. + * + * @return updatedAccumulator. + */ + public Output updatedAccumulator() { + return updatedAccumulator; + } + + /** + * Gets updatedLinear. + * + * @return updatedLinear. + */ + public Output updatedLinear() { + return updatedLinear; + } + + /** + * Optional attributes for {@link org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithFtrlAndCsrInput} + */ + public static class Options { + private Float clipWeightMin; + + private Float clipWeightMax; + + private Options() { + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public Options clipWeightMin(Float clipWeightMin) { + this.clipWeightMin = clipWeightMin; + return this; + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. 
+ */ + public Options clipWeightMax(Float clipWeightMax) { + this.clipWeightMax = clipWeightMax; + return this; + } + } + + @OpInputsMetadata( + outputsClass = XlaSparseDenseMatmulGradWithFtrlAndCsrInput.class + ) + public static class Inputs extends RawOpInputs { + /** + * The rowPointers input + */ + public final Operand rowPointers; + + /** + * The sortedSampleIds input + */ + public final Operand sortedSampleIds; + + /** + * The sortedTokenIds input + */ + public final Operand sortedTokenIds; + + /** + * The sortedGains input + */ + public final Operand sortedGains; + + /** + * The activationGradients input + */ + public final Operand activationGradients; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The accumulator input + */ + public final Operand accumulator; + + /** + * The linear input + */ + public final Operand linear; + + /** + * The numMinibatchesPerPhysicalSparseCore input + */ + public final Operand numMinibatchesPerPhysicalSparseCore; + + /** + * The multiplyLinearByLearningRate attribute + */ + public final boolean multiplyLinearByLearningRate; + + /** + * The beta attribute + */ + public final float beta; + + /** + * The learningRatePower attribute + */ + public final float learningRatePower; + + /** + * The l1RegularizationStrength attribute + */ + public final float l1RegularizationStrength; + + /** + * The l2RegularizationStrength attribute + */ + public final float l2RegularizationStrength; + + /** + * The clipWeightMin attribute + */ + public final float clipWeightMin; + + /** + * The clipWeightMax attribute + */ + public final float clipWeightMax; + + /** + * The tableName attribute + */ + public final String tableName; + + public Inputs(GraphOperation op) { + super(new XlaSparseDenseMatmulGradWithFtrlAndCsrInput(op), op, Arrays.asList("multiply_linear_by_learning_rate", "beta", "learning_rate_power", 
"l1_regularization_strength", "l2_regularization_strength", "clip_weight_min", "clip_weight_max", "table_name")); + int inputIndex = 0; + rowPointers = (Operand) op.input(inputIndex++); + sortedSampleIds = (Operand) op.input(inputIndex++); + sortedTokenIds = (Operand) op.input(inputIndex++); + sortedGains = (Operand) op.input(inputIndex++); + activationGradients = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + accumulator = (Operand) op.input(inputIndex++); + linear = (Operand) op.input(inputIndex++); + numMinibatchesPerPhysicalSparseCore = (Operand) op.input(inputIndex++); + multiplyLinearByLearningRate = op.attributes().getAttrBool("multiply_linear_by_learning_rate"); + beta = op.attributes().getAttrFloat("beta"); + learningRatePower = op.attributes().getAttrFloat("learning_rate_power"); + l1RegularizationStrength = op.attributes().getAttrFloat("l1_regularization_strength"); + l2RegularizationStrength = op.attributes().getAttrFloat("l2_regularization_strength"); + clipWeightMin = op.attributes().getAttrFloat("clip_weight_min"); + clipWeightMax = op.attributes().getAttrFloat("clip_weight_max"); + tableName = op.attributes().getAttrString("table_name"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithSgdAndCsrInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithSgdAndCsrInput.java new file mode 100644 index 00000000000..bfb8a3a127f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulGradWithSgdAndCsrInput.java @@ -0,0 +1,250 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! + +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseDenseMatmulGradWithSgdAndCsrInput operation + */ +@OpMetadata( + opType = XlaSparseDenseMatmulGradWithSgdAndCsrInput.OP_NAME, + inputsClass = XlaSparseDenseMatmulGradWithSgdAndCsrInput.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseDenseMatmulGradWithSgdAndCsrInput extends RawOp implements Operand { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseDenseMatmulGradWithSgdAndCsrInput"; + + private Output updatedEmbeddingTable; + + public XlaSparseDenseMatmulGradWithSgdAndCsrInput(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + updatedEmbeddingTable = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseDenseMatmulGradWithSgdAndCsrInput operation. 
+ * + * @param scope current scope + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param activationGradients The activationGradients value + * @param learningRate The learningRate value + * @param embeddingTable The embeddingTable value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param tableName The value of the tableName attribute + * @param options carries optional attribute values + * @return a new instance of XlaSparseDenseMatmulGradWithSgdAndCsrInput + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseDenseMatmulGradWithSgdAndCsrInput create(Scope scope, + Operand rowPointers, Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand activationGradients, + Operand learningRate, Operand embeddingTable, + Operand numMinibatchesPerPhysicalSparseCore, String tableName, Options... options) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseDenseMatmulGradWithSgdAndCsrInput"); + opBuilder.addInput(rowPointers.asOutput()); + opBuilder.addInput(sortedSampleIds.asOutput()); + opBuilder.addInput(sortedTokenIds.asOutput()); + opBuilder.addInput(sortedGains.asOutput()); + opBuilder.addInput(activationGradients.asOutput()); + opBuilder.addInput(learningRate.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(numMinibatchesPerPhysicalSparseCore.asOutput()); + opBuilder.setAttr("table_name", tableName); + if (options != null) { + for (Options opts : options) { + if (opts.clipWeightMin != null) { + opBuilder.setAttr("clip_weight_min", opts.clipWeightMin); + } + if (opts.clipWeightMax != null) { + opBuilder.setAttr("clip_weight_max", opts.clipWeightMax); + } + } + } + return new XlaSparseDenseMatmulGradWithSgdAndCsrInput(opBuilder.build()); + } + + /** + * Sets the clipWeightMin option. 
+ * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public static Options clipWeightMin(Float clipWeightMin) { + return new Options().clipWeightMin(clipWeightMin); + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. + */ + public static Options clipWeightMax(Float clipWeightMax) { + return new Options().clipWeightMax(clipWeightMax); + } + + /** + * Gets updatedEmbeddingTable. + * + * @return updatedEmbeddingTable. + */ + public Output updatedEmbeddingTable() { + return updatedEmbeddingTable; + } + + @Override + public Output asOutput() { + return updatedEmbeddingTable; + } + + /** + * Optional attributes for {@link org.tensorflow.op.xla.XlaSparseDenseMatmulGradWithSgdAndCsrInput} + */ + public static class Options { + private Float clipWeightMin; + + private Float clipWeightMax; + + private Options() { + } + + /** + * Sets the clipWeightMin option. + * + * @param clipWeightMin the clipWeightMin option + * @return this Options instance. + */ + public Options clipWeightMin(Float clipWeightMin) { + this.clipWeightMin = clipWeightMin; + return this; + } + + /** + * Sets the clipWeightMax option. + * + * @param clipWeightMax the clipWeightMax option + * @return this Options instance. 
+ */ + public Options clipWeightMax(Float clipWeightMax) { + this.clipWeightMax = clipWeightMax; + return this; + } + } + + @OpInputsMetadata( + outputsClass = XlaSparseDenseMatmulGradWithSgdAndCsrInput.class + ) + public static class Inputs extends RawOpInputs { + /** + * The rowPointers input + */ + public final Operand rowPointers; + + /** + * The sortedSampleIds input + */ + public final Operand sortedSampleIds; + + /** + * The sortedTokenIds input + */ + public final Operand sortedTokenIds; + + /** + * The sortedGains input + */ + public final Operand sortedGains; + + /** + * The activationGradients input + */ + public final Operand activationGradients; + + /** + * The learningRate input + */ + public final Operand learningRate; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The numMinibatchesPerPhysicalSparseCore input + */ + public final Operand numMinibatchesPerPhysicalSparseCore; + + /** + * The clipWeightMin attribute + */ + public final float clipWeightMin; + + /** + * The clipWeightMax attribute + */ + public final float clipWeightMax; + + /** + * The tableName attribute + */ + public final String tableName; + + public Inputs(GraphOperation op) { + super(new XlaSparseDenseMatmulGradWithSgdAndCsrInput(op), op, Arrays.asList("clip_weight_min", "clip_weight_max", "table_name")); + int inputIndex = 0; + rowPointers = (Operand) op.input(inputIndex++); + sortedSampleIds = (Operand) op.input(inputIndex++); + sortedTokenIds = (Operand) op.input(inputIndex++); + sortedGains = (Operand) op.input(inputIndex++); + activationGradients = (Operand) op.input(inputIndex++); + learningRate = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + numMinibatchesPerPhysicalSparseCore = (Operand) op.input(inputIndex++); + clipWeightMin = op.attributes().getAttrFloat("clip_weight_min"); + clipWeightMax = op.attributes().getAttrFloat("clip_weight_max"); + tableName = 
op.attributes().getAttrString("table_name"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulWithCsrInput.java b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulWithCsrInput.java new file mode 100644 index 00000000000..793c460676f --- /dev/null +++ b/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/xla/XlaSparseDenseMatmulWithCsrInput.java @@ -0,0 +1,190 @@ +/* Copyright 2018-2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ + +// This class has been generated, DO NOT EDIT! 
+ +package org.tensorflow.op.xla; + +import java.util.Arrays; +import org.tensorflow.GraphOperation; +import org.tensorflow.Operand; +import org.tensorflow.Operation; +import org.tensorflow.OperationBuilder; +import org.tensorflow.Output; +import org.tensorflow.op.RawOp; +import org.tensorflow.op.RawOpInputs; +import org.tensorflow.op.Scope; +import org.tensorflow.op.annotation.Endpoint; +import org.tensorflow.op.annotation.OpInputsMetadata; +import org.tensorflow.op.annotation.OpMetadata; +import org.tensorflow.op.annotation.Operator; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TInt32; + +/** + * The XlaSparseDenseMatmulWithCsrInput operation + */ +@OpMetadata( + opType = XlaSparseDenseMatmulWithCsrInput.OP_NAME, + inputsClass = XlaSparseDenseMatmulWithCsrInput.Inputs.class +) +@Operator( + group = "xla" +) +public final class XlaSparseDenseMatmulWithCsrInput extends RawOp implements Operand { + /** + * The name of this op, as known by TensorFlow core engine + */ + public static final String OP_NAME = "XlaSparseDenseMatmulWithCsrInput"; + + private Output activations; + + public XlaSparseDenseMatmulWithCsrInput(Operation operation) { + super(operation, OP_NAME); + int outputIdx = 0; + activations = operation.output(outputIdx++); + } + + /** + * Factory method to create a class wrapping a new XlaSparseDenseMatmulWithCsrInput operation. 
+ * + * @param scope current scope + * @param rowPointers The rowPointers value + * @param sortedSampleIds The sortedSampleIds value + * @param sortedTokenIds The sortedTokenIds value + * @param sortedGains The sortedGains value + * @param embeddingTable The embeddingTable value + * @param numMinibatchesPerPhysicalSparseCore The numMinibatchesPerPhysicalSparseCore value + * @param inputSize The value of the inputSize attribute + * @param quantizationConfigLow The value of the quantizationConfigLow attribute + * @param quantizationConfigHigh The value of the quantizationConfigHigh attribute + * @param quantizationConfigNumBuckets The value of the quantizationConfigNumBuckets attribute + * @param tableName The value of the tableName attribute + * @return a new instance of XlaSparseDenseMatmulWithCsrInput + */ + @Endpoint( + describeByClass = true + ) + public static XlaSparseDenseMatmulWithCsrInput create(Scope scope, Operand rowPointers, + Operand sortedSampleIds, Operand sortedTokenIds, + Operand sortedGains, Operand embeddingTable, + Operand numMinibatchesPerPhysicalSparseCore, Long inputSize, + Float quantizationConfigLow, Float quantizationConfigHigh, Long quantizationConfigNumBuckets, + String tableName) { + OperationBuilder opBuilder = scope.opBuilder(OP_NAME, "XlaSparseDenseMatmulWithCsrInput"); + opBuilder.addInput(rowPointers.asOutput()); + opBuilder.addInput(sortedSampleIds.asOutput()); + opBuilder.addInput(sortedTokenIds.asOutput()); + opBuilder.addInput(sortedGains.asOutput()); + opBuilder.addInput(embeddingTable.asOutput()); + opBuilder.addInput(numMinibatchesPerPhysicalSparseCore.asOutput()); + opBuilder.setAttr("input_size", inputSize); + opBuilder.setAttr("quantization_config_low", quantizationConfigLow); + opBuilder.setAttr("quantization_config_high", quantizationConfigHigh); + opBuilder.setAttr("quantization_config_num_buckets", quantizationConfigNumBuckets); + opBuilder.setAttr("table_name", tableName); + return new 
XlaSparseDenseMatmulWithCsrInput(opBuilder.build()); + } + + /** + * Gets activations. + * + * @return activations. + */ + public Output activations() { + return activations; + } + + @Override + public Output asOutput() { + return activations; + } + + @OpInputsMetadata( + outputsClass = XlaSparseDenseMatmulWithCsrInput.class + ) + public static class Inputs extends RawOpInputs { + /** + * The rowPointers input + */ + public final Operand rowPointers; + + /** + * The sortedSampleIds input + */ + public final Operand sortedSampleIds; + + /** + * The sortedTokenIds input + */ + public final Operand sortedTokenIds; + + /** + * The sortedGains input + */ + public final Operand sortedGains; + + /** + * The embeddingTable input + */ + public final Operand embeddingTable; + + /** + * The numMinibatchesPerPhysicalSparseCore input + */ + public final Operand numMinibatchesPerPhysicalSparseCore; + + /** + * The inputSize attribute + */ + public final long inputSize; + + /** + * The quantizationConfigLow attribute + */ + public final float quantizationConfigLow; + + /** + * The quantizationConfigHigh attribute + */ + public final float quantizationConfigHigh; + + /** + * The quantizationConfigNumBuckets attribute + */ + public final long quantizationConfigNumBuckets; + + /** + * The tableName attribute + */ + public final String tableName; + + public Inputs(GraphOperation op) { + super(new XlaSparseDenseMatmulWithCsrInput(op), op, Arrays.asList("input_size", "quantization_config_low", "quantization_config_high", "quantization_config_num_buckets", "table_name")); + int inputIndex = 0; + rowPointers = (Operand) op.input(inputIndex++); + sortedSampleIds = (Operand) op.input(inputIndex++); + sortedTokenIds = (Operand) op.input(inputIndex++); + sortedGains = (Operand) op.input(inputIndex++); + embeddingTable = (Operand) op.input(inputIndex++); + numMinibatchesPerPhysicalSparseCore = (Operand) op.input(inputIndex++); + inputSize = op.attributes().getAttrInt("input_size"); + 
quantizationConfigLow = op.attributes().getAttrFloat("quantization_config_low"); + quantizationConfigHigh = op.attributes().getAttrFloat("quantization_config_high"); + quantizationConfigNumBuckets = op.attributes().getAttrInt("quantization_config_num_buckets"); + tableName = op.attributes().getAttrString("table_name"); + } + } +} diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java index 17607b2e937..21935032de4 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java +++ b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/OpGenerator.java @@ -316,7 +316,7 @@ private void createApiDef(OpDef opDef, File apiDefFile) throws IOException { ApiDef.Visibility visibility = null; do { System.out.print( - " Choose visibility of this op [v]isible/[h]idden/[s]kip/[d]efault (default=d): "); + " Choose visibility of this op [v]isible/[h]idden/[s]kip/[d]efault (default=v): "); var value = USER_PROMPT.nextLine().trim(); if (!value.isEmpty()) { switch (value) { @@ -343,7 +343,7 @@ private void createApiDef(OpDef opDef, File apiDefFile) throws IOException { break; } } else { - visibility = ApiDef.Visibility.DEFAULT_VISIBILITY; + visibility = ApiDef.Visibility.VISIBLE; } } while (visibility == null); diff --git a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/javadoc/CoreJavaDocNodeRenderer.java b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/javadoc/CoreJavaDocNodeRenderer.java index 82be0136b2b..58da60c7d8b 100644 --- a/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/javadoc/CoreJavaDocNodeRenderer.java +++ 
b/tensorflow-core/tensorflow-core-generator/src/main/java/org/tensorflow/generator/op/javadoc/CoreJavaDocNodeRenderer.java @@ -3,7 +3,6 @@ import com.google.common.base.CaseFormat; import java.util.Arrays; import java.util.Collections; -import java.util.HashMap; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.Map; @@ -152,14 +151,10 @@ public class CoreJavaDocNodeRenderer extends AbstractVisitor implements NodeRend }; private static final Set allowedHtml5Tags = new HashSet<>(Arrays.asList(html5Tags)); private static final Map urlLinkConversion = - new HashMap() { - { - put("../../../api_docs/python/math_ops", "org.tensorflow.op.MathOps"); - put( - "https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update", + Map.of( + "../../../api_docs/python/math_ops", "org.tensorflow.op.MathOps", + "https://www.tensorflow.org/api_docs/python/tf/tensor_scatter_nd_update", "org.tensorflow.op.Ops#tensorScatterNdUpdate"); - } - }; protected final JavaDocNodeRendererContext context; private final JavaDocWriter writer; private boolean firstParagraph; diff --git a/tensorflow-core/tensorflow-core-native/.bazelrc b/tensorflow-core/tensorflow-core-native/.bazelrc index 8c72fa1398b..7437d982e91 100644 --- a/tensorflow-core/tensorflow-core-native/.bazelrc +++ b/tensorflow-core/tensorflow-core-native/.bazelrc @@ -1,3 +1,5 @@ -#build --remote_cache=https://storage.googleapis.com/tensorflow-sigs-jvm -#build --remote_upload_local_results=false build --experimental_ui_max_stdouterr_bytes=-1 + +# We don't need this cache, since we are not building TF from GitHub Actions anymore. 
Just keeping this here for reference +# build --remote_cache=https://storage.googleapis.com/tensorflow-sigs-jvm +# build --remote_upload_local_results=false diff --git a/tensorflow-core/tensorflow-core-native/.bazelversion b/tensorflow-core/tensorflow-core-native/.bazelversion index 358e78e6074..f3c238740e5 100644 --- a/tensorflow-core/tensorflow-core-native/.bazelversion +++ b/tensorflow-core/tensorflow-core-native/.bazelversion @@ -1 +1,2 @@ -6.1.0 \ No newline at end of file +6.5.0 +# NOTE: Update Bazel version in tensorflow/tools/ci_build/release/common.sh.oss \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-native/WORKSPACE b/tensorflow-core/tensorflow-core-native/WORKSPACE index ad2c878d529..93acef27be1 100644 --- a/tensorflow-core/tensorflow-core-native/WORKSPACE +++ b/tensorflow-core/tensorflow-core-native/WORKSPACE @@ -18,10 +18,10 @@ http_archive( "find tensorflow third_party/xla/third_party/tsl -name \\*.proto | xargs sed -i.bak 's/^package tensorflow\\([^;]*\\).*$/package tensorflow\\1;\\noption java_package = \"org.tensorflow.proto\\1\";/'", ], urls = [ - "https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.15.0.tar.gz", + "https://github.com/tensorflow/tensorflow/archive/refs/tags/v2.16.1.tar.gz", ], - sha256 = "9cec5acb0ecf2d47b16891f8bc5bc6fbfdffe1700bdadc0d9ebe27ea34f0c220", - strip_prefix = "tensorflow-2.15.0" + sha256 = "c729e56efc945c6df08efe5c9f5b8b89329c7c91b8f40ad2bb3e13900bd4876d", + strip_prefix = "tensorflow-2.16.1" ) ##### Copy content of tensorflow/WORKSPACE here (make sure to change references of default package "//" to "@org_tensorflow//") @@ -57,12 +57,12 @@ load( python_repository(name = "python_version_repo") -load("@python_version_repo//:py_version.bzl", "HERMETIC_PYTHON_VERSION") +load("@python_version_repo//:py_version.bzl", "TF_PYTHON_VERSION") python_register_toolchains( name = "python", ignore_root_user_error = True, - python_version = HERMETIC_PYTHON_VERSION, + python_version = 
TF_PYTHON_VERSION, ) load("@python//:defs.bzl", "interpreter") @@ -115,4 +115,4 @@ tf_workspace1() load("@org_tensorflow//tensorflow:workspace0.bzl", "tf_workspace0") -tf_workspace0() +tf_workspace0() \ No newline at end of file diff --git a/tensorflow-core/tensorflow-core-native/pom.xml b/tensorflow-core/tensorflow-core-native/pom.xml index 00c67371547..066ca12be97 100644 --- a/tensorflow-core/tensorflow-core-native/pom.xml +++ b/tensorflow-core/tensorflow-core-native/pom.xml @@ -69,7 +69,9 @@ false false + false false + false true diff --git a/tensorflow-core/tensorflow-core-native/scripts/dist_download.sh b/tensorflow-core/tensorflow-core-native/scripts/dist_download.sh index a97220538c3..c5b10d50d48 100755 --- a/tensorflow-core/tensorflow-core-native/scripts/dist_download.sh +++ b/tensorflow-core/tensorflow-core-native/scripts/dist_download.sh @@ -5,20 +5,20 @@ DOWNLOAD_FOLDER="$1" case ${PLATFORM:-} in 'linux-x86_64') - WHEEL_URL='https://files.pythonhosted.org/packages/fa/44/a1698c62942d20cab378ba201a6cbfcce579418351a0c6e4ea9d66c9adf2/tensorflow_cpu-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl' + WHEEL_URL='https://files.pythonhosted.org/packages/9b/0b/c18d6464a19d4c9b63df8880dd3ce0c67b5145ada9092f3ac67d82726566/tensorflow_cpu-2.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl' ;; 'linux-x86_64-gpu') - WHEEL_URL='https://files.pythonhosted.org/packages/93/c0/a774286d0383419f558deb27096e5de9f9facd6c27df8e9f9af6fba2f77e/tensorflow-2.15.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl' + WHEEL_URL='https://files.pythonhosted.org/packages/58/70/e8ac764ec80810eefcbab0cb1d21dbba6cf26719c44cd6d9a5e9f0407935/tensorflow-2.16.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl' ;; 'macosx-x86_64') - WHEEL_URL='https://files.pythonhosted.org/packages/92/2d/880fcd65e4414b05088193e6f2cfb86fdf90003dd2dd0f4d1bc465348f0e/tensorflow-2.15.0-cp311-cp311-macosx_10_15_x86_64.whl' + 
WHEEL_URL='https://files.pythonhosted.org/packages/b8/75/ce4d8eeb1fb100726634358411bc4a8b12f889f6ce560b0973c0a5dbac39/tensorflow-2.16.1-cp311-cp311-macosx_10_15_x86_64.whl' ;; 'macosx-arm64') - WHEEL_URL='https://files.pythonhosted.org/packages/eb/9f/0759e2fea4a3c48f070b64811c2c57036b46353ba87263afc810b8f4188a/tensorflow_macos-2.15.0-cp311-cp311-macosx_12_0_arm64.whl' + WHEEL_URL='https://files.pythonhosted.org/packages/f9/14/67e9b2b2379cb530c0412123a674d045eca387dfcfa7db1c0028857b0a66/tensorflow-2.16.1-cp311-cp311-macosx_12_0_arm64.whl' ;; 'windows-x86_64') - WHEEL_URL='https://files.pythonhosted.org/packages/4c/48/1a5a15517f18eaa4ff8d598b1c000300b20c1bb0e624539d702117a0c369/tensorflow_intel-2.15.0-cp311-cp311-win_amd64.whl' - CLIB_URL='https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-cpu-windows-x86_64-2.15.0.zip' + WHEEL_URL='https://files.pythonhosted.org/packages/e0/36/6278e4e7e69a90c00e0f82944d8f2713dd85a69d1add455d9e50446837ab/tensorflow_intel-2.16.1-cp311-cp311-win_amd64.whl' + CLIB_URL='https://storage.googleapis.com/tensorflow/versions/2.16.1/libtensorflow-cpu-windows-x86_64.zip' ;; *) echo "TensorFlow distribution for ${PLATFORM} is not supported for download" diff --git a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java index c5be5971ff1..83b6f22e528 100644 --- a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java +++ b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/internal/c_api/global/tensorflow.java @@ -517,6 +517,7 @@ public static native void TF_TString_Copy(TF_TString dst, String src, // -------------------------------------------------------------------------- // TF_Code holds an error code. The enum values here are identical to // corresponding values in error_codes.proto. 
+// LINT.IfChange public static final int TF_OK = TSL_OK; public static final int TF_CANCELLED = TSL_CANCELLED; public static final int TF_UNKNOWN = TSL_UNKNOWN; @@ -534,6 +535,7 @@ public static native void TF_TString_Copy(TF_TString dst, String src, public static final int TF_INTERNAL = TSL_INTERNAL; public static final int TF_UNAVAILABLE = TSL_UNAVAILABLE; public static final int TF_DATA_LOSS = TSL_DATA_LOSS; +// LINT.ThenChange(//tensorflow/python/py_exception_registry_wrapper.cc) // -------------------------------------------------------------------------- @@ -2143,6 +2145,13 @@ public static native TF_ImportGraphDefResults TF_GraphImportGraphDefWithResults( @Const TF_ImportGraphDefOptions options, TF_Status status); +// Has the same behavior as TF_GraphImportGraphDefWithResults, but instead of +// taking in a serialized tensorflow::GraphDef, it takes in a *pointer* to the +// C++ *in memory representation* of the GraphDef, stored in `graph_def->data` +public static native TF_ImportGraphDefResults TF_GraphImportGraphDefWithResultsNoSerialization( + TF_Graph graph, @Const TF_Buffer graph_def, + @Const TF_ImportGraphDefOptions options, TF_Status status); + // Import the graph serialized in `graph_def` into `graph`. // Convenience function for when only return outputs are needed. // diff --git a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProto.java b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProto.java index 889a6aa39aa..04d15e4a308 100644 --- a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProto.java +++ b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProto.java @@ -344,6 +344,60 @@ public interface ExperimentalOrBuilder extends */ boolean getUseTfrt(); + /** + *

+     * If true, use Pathways with TFRT API for multi host support.
+     * 
+ * + * bool enable_multi_host = 27; + * @return The enableMultiHost. + */ + boolean getEnableMultiHost(); + + /** + *
+     * Port for the Pathways server. Ignored if enable_multi_host=false.
+     * 
+ * + * int32 backend_server_port = 28; + * @return The backendServerPort. + */ + int getBackendServerPort(); + + /** + *
+     * If true, TFRT will use TPU specific compiler passes and perform TPU
+     * specific initialization.
+     * 
+ * + * bool target_tpu = 29; + * @return The targetTpu. + */ + boolean getTargetTpu(); + + /** + *
+     * If true, TFRT will use GPU specific compiler passes and perform GPU
+     * specific initialization.
+     * 
+ * + * bool target_gpu = 30; + * @return The targetGpu. + */ + boolean getTargetGpu(); + + /** + *
+     * The threshold to merge small streams in TFRT. The stream with cost
+     * smaller than the threshold will be merged. Setting it to value 1
+     * disables all merges.
+     * 
+ * + * int32 stream_merge_threshold = 31; + * @return The streamMergeThreshold. + */ + int getStreamMergeThreshold(); + /** *
      * Whether functional control flow op lowering should be disabled. This is
@@ -1032,6 +1086,85 @@ public boolean getUseTfrt() {
       return useTfrt_;
     }
 
+    public static final int ENABLE_MULTI_HOST_FIELD_NUMBER = 27;
+    private boolean enableMultiHost_;
+    /**
+     * 
+     * If true, use Pathways with TFRT API for multi host support.
+     * 
+ * + * bool enable_multi_host = 27; + * @return The enableMultiHost. + */ + @java.lang.Override + public boolean getEnableMultiHost() { + return enableMultiHost_; + } + + public static final int BACKEND_SERVER_PORT_FIELD_NUMBER = 28; + private int backendServerPort_; + /** + *
+     * Port for the Pathways server. Ignored if enable_multi_host=false.
+     * 
+ * + * int32 backend_server_port = 28; + * @return The backendServerPort. + */ + @java.lang.Override + public int getBackendServerPort() { + return backendServerPort_; + } + + public static final int TARGET_TPU_FIELD_NUMBER = 29; + private boolean targetTpu_; + /** + *
+     * If true, TFRT will use TPU specific compiler passes and perform TPU
+     * specific initialization.
+     * 
+ * + * bool target_tpu = 29; + * @return The targetTpu. + */ + @java.lang.Override + public boolean getTargetTpu() { + return targetTpu_; + } + + public static final int TARGET_GPU_FIELD_NUMBER = 30; + private boolean targetGpu_; + /** + *
+     * If true, TFRT will use GPU specific compiler passes and perform GPU
+     * specific initialization.
+     * 
+ * + * bool target_gpu = 30; + * @return The targetGpu. + */ + @java.lang.Override + public boolean getTargetGpu() { + return targetGpu_; + } + + public static final int STREAM_MERGE_THRESHOLD_FIELD_NUMBER = 31; + private int streamMergeThreshold_; + /** + *
+     * The threshold to merge small streams in TFRT. The stream with cost
+     * smaller than the threshold will be merged. Setting it to value 1
+     * disables all merges.
+     * 
+ * + * int32 stream_merge_threshold = 31; + * @return The streamMergeThreshold. + */ + @java.lang.Override + public int getStreamMergeThreshold() { + return streamMergeThreshold_; + } + public static final int DISABLE_FUNCTIONAL_OPS_LOWERING_FIELD_NUMBER = 21; private boolean disableFunctionalOpsLowering_; /** @@ -1221,6 +1354,21 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) if (disableEagerExecutorStreamingEnqueue_ != false) { output.writeBool(26, disableEagerExecutorStreamingEnqueue_); } + if (enableMultiHost_ != false) { + output.writeBool(27, enableMultiHost_); + } + if (backendServerPort_ != 0) { + output.writeInt32(28, backendServerPort_); + } + if (targetTpu_ != false) { + output.writeBool(29, targetTpu_); + } + if (targetGpu_ != false) { + output.writeBool(30, targetGpu_); + } + if (streamMergeThreshold_ != 0) { + output.writeInt32(31, streamMergeThreshold_); + } getUnknownFields().writeTo(output); } @@ -1316,6 +1464,26 @@ public int getSerializedSize() { size += com.google.protobuf.CodedOutputStream .computeBoolSize(26, disableEagerExecutorStreamingEnqueue_); } + if (enableMultiHost_ != false) { + size += com.google.protobuf.CodedOutputStream + .computeBoolSize(27, enableMultiHost_); + } + if (backendServerPort_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(28, backendServerPort_); + } + if (targetTpu_ != false) { + size += com.google.protobuf.CodedOutputStream + .computeBoolSize(29, targetTpu_); + } + if (targetGpu_ != false) { + size += com.google.protobuf.CodedOutputStream + .computeBoolSize(30, targetGpu_); + } + if (streamMergeThreshold_ != 0) { + size += com.google.protobuf.CodedOutputStream + .computeInt32Size(31, streamMergeThreshold_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -1367,6 +1535,16 @@ public boolean equals(final java.lang.Object obj) { != other.getXlaFusionAutotunerThresh()) return false; if (getUseTfrt() != other.getUseTfrt()) return 
false; + if (getEnableMultiHost() + != other.getEnableMultiHost()) return false; + if (getBackendServerPort() + != other.getBackendServerPort()) return false; + if (getTargetTpu() + != other.getTargetTpu()) return false; + if (getTargetGpu() + != other.getTargetGpu()) return false; + if (getStreamMergeThreshold() + != other.getStreamMergeThreshold()) return false; if (getDisableFunctionalOpsLowering() != other.getDisableFunctionalOpsLowering()) return false; if (getXlaPreferSingleGraphCluster() @@ -1439,6 +1617,19 @@ public int hashCode() { hash = (37 * hash) + USE_TFRT_FIELD_NUMBER; hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( getUseTfrt()); + hash = (37 * hash) + ENABLE_MULTI_HOST_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( + getEnableMultiHost()); + hash = (37 * hash) + BACKEND_SERVER_PORT_FIELD_NUMBER; + hash = (53 * hash) + getBackendServerPort(); + hash = (37 * hash) + TARGET_TPU_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( + getTargetTpu()); + hash = (37 * hash) + TARGET_GPU_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( + getTargetGpu()); + hash = (37 * hash) + STREAM_MERGE_THRESHOLD_FIELD_NUMBER; + hash = (53 * hash) + getStreamMergeThreshold(); hash = (37 * hash) + DISABLE_FUNCTIONAL_OPS_LOWERING_FIELD_NUMBER; hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( getDisableFunctionalOpsLowering()); @@ -1627,6 +1818,16 @@ public Builder clear() { useTfrt_ = false; + enableMultiHost_ = false; + + backendServerPort_ = 0; + + targetTpu_ = false; + + targetGpu_ = false; + + streamMergeThreshold_ = 0; + disableFunctionalOpsLowering_ = false; xlaPreferSingleGraphCluster_ = false; @@ -1688,6 +1889,11 @@ public org.tensorflow.proto.ConfigProto.Experimental buildPartial() { result.disableOutputPartitionGraphs_ = disableOutputPartitionGraphs_; result.xlaFusionAutotunerThresh_ = xlaFusionAutotunerThresh_; result.useTfrt_ = useTfrt_; + 
result.enableMultiHost_ = enableMultiHost_; + result.backendServerPort_ = backendServerPort_; + result.targetTpu_ = targetTpu_; + result.targetGpu_ = targetGpu_; + result.streamMergeThreshold_ = streamMergeThreshold_; result.disableFunctionalOpsLowering_ = disableFunctionalOpsLowering_; result.xlaPreferSingleGraphCluster_ = xlaPreferSingleGraphCluster_; if (coordinationConfigBuilder_ == null) { @@ -1798,6 +2004,21 @@ public Builder mergeFrom(org.tensorflow.proto.ConfigProto.Experimental other) { if (other.getUseTfrt() != false) { setUseTfrt(other.getUseTfrt()); } + if (other.getEnableMultiHost() != false) { + setEnableMultiHost(other.getEnableMultiHost()); + } + if (other.getBackendServerPort() != 0) { + setBackendServerPort(other.getBackendServerPort()); + } + if (other.getTargetTpu() != false) { + setTargetTpu(other.getTargetTpu()); + } + if (other.getTargetGpu() != false) { + setTargetGpu(other.getTargetGpu()); + } + if (other.getStreamMergeThreshold() != 0) { + setStreamMergeThreshold(other.getStreamMergeThreshold()); + } if (other.getDisableFunctionalOpsLowering() != false) { setDisableFunctionalOpsLowering(other.getDisableFunctionalOpsLowering()); } @@ -1953,6 +2174,31 @@ public Builder mergeFrom( break; } // case 208 + case 216: { + enableMultiHost_ = input.readBool(); + + break; + } // case 216 + case 224: { + backendServerPort_ = input.readInt32(); + + break; + } // case 224 + case 232: { + targetTpu_ = input.readBool(); + + break; + } // case 232 + case 240: { + targetGpu_ = input.readBool(); + + break; + } // case 240 + case 248: { + streamMergeThreshold_ = input.readInt32(); + + break; + } // case 248 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -3134,6 +3380,233 @@ public Builder clearUseTfrt() { return this; } + private boolean enableMultiHost_ ; + /** + *
+       * If true, use Pathways with TFRT API for multi host support.
+       * 
+ * + * bool enable_multi_host = 27; + * @return The enableMultiHost. + */ + @java.lang.Override + public boolean getEnableMultiHost() { + return enableMultiHost_; + } + /** + *
+       * If true, use Pathways with TFRT API for multi host support.
+       * 
+ * + * bool enable_multi_host = 27; + * @param value The enableMultiHost to set. + * @return This builder for chaining. + */ + public Builder setEnableMultiHost(boolean value) { + + enableMultiHost_ = value; + onChanged(); + return this; + } + /** + *
+       * If true, use Pathways with TFRT API for multi host support.
+       * 
+ * + * bool enable_multi_host = 27; + * @return This builder for chaining. + */ + public Builder clearEnableMultiHost() { + + enableMultiHost_ = false; + onChanged(); + return this; + } + + private int backendServerPort_ ; + /** + *
+       * Port for the Pathways server. Ignored if enable_multi_host=false.
+       * 
+ * + * int32 backend_server_port = 28; + * @return The backendServerPort. + */ + @java.lang.Override + public int getBackendServerPort() { + return backendServerPort_; + } + /** + *
+       * Port for the Pathways server. Ignored if enable_multi_host=false.
+       * 
+ * + * int32 backend_server_port = 28; + * @param value The backendServerPort to set. + * @return This builder for chaining. + */ + public Builder setBackendServerPort(int value) { + + backendServerPort_ = value; + onChanged(); + return this; + } + /** + *
+       * Port for the Pathways server. Ignored if enable_multi_host=false.
+       * 
+ * + * int32 backend_server_port = 28; + * @return This builder for chaining. + */ + public Builder clearBackendServerPort() { + + backendServerPort_ = 0; + onChanged(); + return this; + } + + private boolean targetTpu_ ; + /** + *
+       * If true, TFRT will use TPU specific compiler passes and perform TPU
+       * specific initialization.
+       * 
+ * + * bool target_tpu = 29; + * @return The targetTpu. + */ + @java.lang.Override + public boolean getTargetTpu() { + return targetTpu_; + } + /** + *
+       * If true, TFRT will use TPU specific compiler passes and perform TPU
+       * specific initialization.
+       * 
+ * + * bool target_tpu = 29; + * @param value The targetTpu to set. + * @return This builder for chaining. + */ + public Builder setTargetTpu(boolean value) { + + targetTpu_ = value; + onChanged(); + return this; + } + /** + *
+       * If true, TFRT will use TPU specific compiler passes and perform TPU
+       * specific initialization.
+       * 
+ * + * bool target_tpu = 29; + * @return This builder for chaining. + */ + public Builder clearTargetTpu() { + + targetTpu_ = false; + onChanged(); + return this; + } + + private boolean targetGpu_ ; + /** + *
+       * If true, TFRT will use GPU specific compiler passes and perform GPU
+       * specific initialization.
+       * 
+ * + * bool target_gpu = 30; + * @return The targetGpu. + */ + @java.lang.Override + public boolean getTargetGpu() { + return targetGpu_; + } + /** + *
+       * If true, TFRT will use GPU specific compiler passes and perform GPU
+       * specific initialization.
+       * 
+ * + * bool target_gpu = 30; + * @param value The targetGpu to set. + * @return This builder for chaining. + */ + public Builder setTargetGpu(boolean value) { + + targetGpu_ = value; + onChanged(); + return this; + } + /** + *
+       * If true, TFRT will use GPU specific compiler passes and perform GPU
+       * specific initialization.
+       * 
+ * + * bool target_gpu = 30; + * @return This builder for chaining. + */ + public Builder clearTargetGpu() { + + targetGpu_ = false; + onChanged(); + return this; + } + + private int streamMergeThreshold_ ; + /** + *
+       * The threshold to merge small streams in TFRT. The stream with cost
+       * smaller than the threshold will be merged. Setting it to value 1
+       * disables all merges.
+       * 
+ * + * int32 stream_merge_threshold = 31; + * @return The streamMergeThreshold. + */ + @java.lang.Override + public int getStreamMergeThreshold() { + return streamMergeThreshold_; + } + /** + *
+       * The threshold to merge small streams in TFRT. The stream with cost
+       * smaller than the threshold will be merged. Setting it to value 1
+       * disables all merges.
+       * 
+ * + * int32 stream_merge_threshold = 31; + * @param value The streamMergeThreshold to set. + * @return This builder for chaining. + */ + public Builder setStreamMergeThreshold(int value) { + + streamMergeThreshold_ = value; + onChanged(); + return this; + } + /** + *
+       * The threshold to merge small streams in TFRT. The stream with cost
+       * smaller than the threshold will be merged. Setting it to value 1
+       * disables all merges.
+       * 
+ * + * int32 stream_merge_threshold = 31; + * @return This builder for chaining. + */ + public Builder clearStreamMergeThreshold() { + + streamMergeThreshold_ = 0; + onChanged(); + return this; + } + private boolean disableFunctionalOpsLowering_ ; /** *
diff --git a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProtos.java b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProtos.java
index 9f8c8190e60..bca6f96f8b0 100644
--- a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProtos.java
+++ b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/ConfigProtos.java
@@ -175,7 +175,7 @@ public static void registerAllExtensions(
       "ession_elimination\"A\n\025ThreadPoolOptionPr" +
       "oto\022\023\n\013num_threads\030\001 \001(\005\022\023\n\013global_name\030" +
       "\002 \001(\t\"0\n\017SessionMetadata\022\014\n\004name\030\001 \001(\t\022\017" +
-      "\n\007version\030\002 \001(\003\"\225\017\n\013ConfigProto\022>\n\014devic" +
+      "\n\007version\030\002 \001(\003\"\225\020\n\013ConfigProto\022>\n\014devic" +
       "e_count\030\001 \003(\0132(.tensorflow.ConfigProto.D" +
       "eviceCountEntry\022$\n\034intra_op_parallelism_" +
       "threads\030\002 \001(\005\022$\n\034inter_op_parallelism_th" +
@@ -194,7 +194,7 @@ public static void registerAllExtensions(
       "hare_cluster_devices_in_session\030\021 \001(\010\022:\n" +
       "\014experimental\030\020 \001(\0132$.tensorflow.ConfigP" +
       "roto.Experimental\0322\n\020DeviceCountEntry\022\013\n" +
-      "\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\005:\0028\001\032\217\t\n\014Experi" +
+      "\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\005:\0028\001\032\217\n\n\014Experi" +
       "mental\022\037\n\027collective_group_leader\030\001 \001(\t\022" +
       "\025\n\rexecutor_type\030\003 \001(\t\022\032\n\022recv_buf_max_c" +
       "hunk\030\004 \001(\005\022\031\n\021use_numa_affinity\030\005 \001(\010\0225\n" +
@@ -211,62 +211,65 @@ public static void registerAllExtensions(
       "idgeRollout\022&\n\036enable_mlir_graph_optimiz" +
       "ation\030\020 \001(\010\022\'\n\037disable_output_partition_" +
       "graphs\030\016 \001(\010\022#\n\033xla_fusion_autotuner_thr" +
-      "esh\030\017 \001(\003\022\020\n\010use_tfrt\030\022 \001(\010\022\'\n\037disable_f" +
-      "unctional_ops_lowering\030\025 \001(\010\022\'\n\037xla_pref" +
-      "er_single_graph_cluster\030\026 \001(\010\022B\n\023coordin" +
-      "ation_config\030\027 \001(\0132%.tensorflow.Coordina" +
-      "tionServiceConfig\022)\n!disable_optimize_fo" +
-      "r_static_graph\030\030 \001(\010\0220\n(disable_eager_ex" +
-      "ecutor_streaming_enqueue\030\032 \001(\010\"\336\001\n\021MlirB" +
-      "ridgeRollout\022#\n\037MLIR_BRIDGE_ROLLOUT_UNSP" +
-      "ECIFIED\020\000\022\037\n\033MLIR_BRIDGE_ROLLOUT_ENABLED" +
-      "\020\001\022 \n\034MLIR_BRIDGE_ROLLOUT_DISABLED\020\002\"\004\010\003" +
-      "\020\003\"\004\010\004\020\004*%MLIR_BRIDGE_ROLLOUT_SAFE_MODE_" +
-      "ENABLED*.MLIR_BRIDGE_ROLLOUT_SAFE_MODE_F" +
-      "ALLBACK_ENABLEDJ\004\010\002\020\003J\004\010\023\020\024J\004\010\024\020\025J\004\010\031\020\032\"" +
-      "\341\004\n\nRunOptions\0226\n\013trace_level\030\001 \001(\0162!.te" +
-      "nsorflow.RunOptions.TraceLevel\022\025\n\rtimeou" +
-      "t_in_ms\030\002 \001(\003\022\034\n\024inter_op_thread_pool\030\003 " +
-      "\001(\005\022\037\n\027output_partition_graphs\030\005 \001(\010\022/\n\r" +
-      "debug_options\030\006 \001(\0132\030.tensorflow.DebugOp" +
-      "tions\022*\n\"report_tensor_allocations_upon_" +
-      "oom\030\007 \001(\010\0229\n\014experimental\030\010 \001(\0132#.tensor" +
-      "flow.RunOptions.Experimental\032\322\001\n\014Experim" +
-      "ental\022\034\n\024collective_graph_key\030\001 \001(\003\022\034\n\024u" +
-      "se_run_handler_pool\030\002 \001(\010\022[\n\030run_handler" +
-      "_pool_options\030\003 \001(\01329.tensorflow.RunOpti" +
-      "ons.Experimental.RunHandlerPoolOptions\032)" +
-      "\n\025RunHandlerPoolOptions\022\020\n\010priority\030\001 \001(" +
-      "\003\"R\n\nTraceLevel\022\014\n\010NO_TRACE\020\000\022\022\n\016SOFTWAR" +
-      "E_TRACE\020\001\022\022\n\016HARDWARE_TRACE\020\002\022\016\n\nFULL_TR" +
-      "ACE\020\003J\004\010\004\020\005\"\276\003\n\013RunMetadata\022)\n\nstep_stat" +
-      "s\030\001 \001(\0132\025.tensorflow.StepStats\022,\n\ncost_g" +
-      "raph\030\002 \001(\0132\030.tensorflow.CostGraphDef\022.\n\020" +
-      "partition_graphs\030\003 \003(\0132\024.tensorflow.Grap" +
-      "hDef\022?\n\017function_graphs\030\004 \003(\0132&.tensorfl" +
-      "ow.RunMetadata.FunctionGraphs\0225\n\020session" +
-      "_metadata\030\005 \001(\0132\033.tensorflow.SessionMeta" +
-      "data\032\255\001\n\016FunctionGraphs\022.\n\020partition_gra" +
-      "phs\030\001 \003(\0132\024.tensorflow.GraphDef\0224\n\026pre_o" +
-      "ptimization_graph\030\002 \001(\0132\024.tensorflow.Gra" +
-      "phDef\0225\n\027post_optimization_graph\030\003 \001(\0132\024" +
-      ".tensorflow.GraphDef\":\n\020TensorConnection" +
-      "\022\023\n\013from_tensor\030\001 \001(\t\022\021\n\tto_tensor\030\002 \001(\t" +
-      "\"\260\003\n\017CallableOptions\022\014\n\004feed\030\001 \003(\t\022\r\n\005fe" +
-      "tch\030\002 \003(\t\022\016\n\006target\030\003 \003(\t\022+\n\013run_options" +
-      "\030\004 \001(\0132\026.tensorflow.RunOptions\0227\n\021tensor" +
-      "_connection\030\005 \003(\0132\034.tensorflow.TensorCon" +
-      "nection\022B\n\014feed_devices\030\006 \003(\0132,.tensorfl" +
-      "ow.CallableOptions.FeedDevicesEntry\022D\n\rf" +
-      "etch_devices\030\007 \003(\0132-.tensorflow.Callable" +
-      "Options.FetchDevicesEntry\022\027\n\017fetch_skip_" +
-      "sync\030\010 \001(\010\0322\n\020FeedDevicesEntry\022\013\n\003key\030\001 " +
-      "\001(\t\022\r\n\005value\030\002 \001(\t:\0028\001\0323\n\021FetchDevicesEn" +
-      "try\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t:\0028\001B\200\001\n\024" +
-      "org.tensorflow.protoB\014ConfigProtosP\001ZUgi" +
-      "thub.com/tensorflow/tensorflow/tensorflo" +
-      "w/go/core/protobuf/for_core_protos_go_pr" +
-      "oto\370\001\001b\006proto3"
+      "esh\030\017 \001(\003\022\020\n\010use_tfrt\030\022 \001(\010\022\031\n\021enable_mu" +
+      "lti_host\030\033 \001(\010\022\033\n\023backend_server_port\030\034 " +
+      "\001(\005\022\022\n\ntarget_tpu\030\035 \001(\010\022\022\n\ntarget_gpu\030\036 " +
+      "\001(\010\022\036\n\026stream_merge_threshold\030\037 \001(\005\022\'\n\037d" +
+      "isable_functional_ops_lowering\030\025 \001(\010\022\'\n\037" +
+      "xla_prefer_single_graph_cluster\030\026 \001(\010\022B\n" +
+      "\023coordination_config\030\027 \001(\0132%.tensorflow." +
+      "CoordinationServiceConfig\022)\n!disable_opt" +
+      "imize_for_static_graph\030\030 \001(\010\0220\n(disable_" +
+      "eager_executor_streaming_enqueue\030\032 \001(\010\"\336" +
+      "\001\n\021MlirBridgeRollout\022#\n\037MLIR_BRIDGE_ROLL" +
+      "OUT_UNSPECIFIED\020\000\022\037\n\033MLIR_BRIDGE_ROLLOUT" +
+      "_ENABLED\020\001\022 \n\034MLIR_BRIDGE_ROLLOUT_DISABL" +
+      "ED\020\002\"\004\010\003\020\003\"\004\010\004\020\004*%MLIR_BRIDGE_ROLLOUT_SA" +
+      "FE_MODE_ENABLED*.MLIR_BRIDGE_ROLLOUT_SAF" +
+      "E_MODE_FALLBACK_ENABLEDJ\004\010\002\020\003J\004\010\023\020\024J\004\010\024\020" +
+      "\025J\004\010\031\020\032\"\341\004\n\nRunOptions\0226\n\013trace_level\030\001 " +
+      "\001(\0162!.tensorflow.RunOptions.TraceLevel\022\025" +
+      "\n\rtimeout_in_ms\030\002 \001(\003\022\034\n\024inter_op_thread" +
+      "_pool\030\003 \001(\005\022\037\n\027output_partition_graphs\030\005" +
+      " \001(\010\022/\n\rdebug_options\030\006 \001(\0132\030.tensorflow" +
+      ".DebugOptions\022*\n\"report_tensor_allocatio" +
+      "ns_upon_oom\030\007 \001(\010\0229\n\014experimental\030\010 \001(\0132" +
+      "#.tensorflow.RunOptions.Experimental\032\322\001\n" +
+      "\014Experimental\022\034\n\024collective_graph_key\030\001 " +
+      "\001(\003\022\034\n\024use_run_handler_pool\030\002 \001(\010\022[\n\030run" +
+      "_handler_pool_options\030\003 \001(\01329.tensorflow" +
+      ".RunOptions.Experimental.RunHandlerPoolO" +
+      "ptions\032)\n\025RunHandlerPoolOptions\022\020\n\010prior" +
+      "ity\030\001 \001(\003\"R\n\nTraceLevel\022\014\n\010NO_TRACE\020\000\022\022\n" +
+      "\016SOFTWARE_TRACE\020\001\022\022\n\016HARDWARE_TRACE\020\002\022\016\n" +
+      "\nFULL_TRACE\020\003J\004\010\004\020\005\"\276\003\n\013RunMetadata\022)\n\ns" +
+      "tep_stats\030\001 \001(\0132\025.tensorflow.StepStats\022," +
+      "\n\ncost_graph\030\002 \001(\0132\030.tensorflow.CostGrap" +
+      "hDef\022.\n\020partition_graphs\030\003 \003(\0132\024.tensorf" +
+      "low.GraphDef\022?\n\017function_graphs\030\004 \003(\0132&." +
+      "tensorflow.RunMetadata.FunctionGraphs\0225\n" +
+      "\020session_metadata\030\005 \001(\0132\033.tensorflow.Ses" +
+      "sionMetadata\032\255\001\n\016FunctionGraphs\022.\n\020parti" +
+      "tion_graphs\030\001 \003(\0132\024.tensorflow.GraphDef\022" +
+      "4\n\026pre_optimization_graph\030\002 \001(\0132\024.tensor" +
+      "flow.GraphDef\0225\n\027post_optimization_graph" +
+      "\030\003 \001(\0132\024.tensorflow.GraphDef\":\n\020TensorCo" +
+      "nnection\022\023\n\013from_tensor\030\001 \001(\t\022\021\n\tto_tens" +
+      "or\030\002 \001(\t\"\260\003\n\017CallableOptions\022\014\n\004feed\030\001 \003" +
+      "(\t\022\r\n\005fetch\030\002 \003(\t\022\016\n\006target\030\003 \003(\t\022+\n\013run" +
+      "_options\030\004 \001(\0132\026.tensorflow.RunOptions\0227" +
+      "\n\021tensor_connection\030\005 \003(\0132\034.tensorflow.T" +
+      "ensorConnection\022B\n\014feed_devices\030\006 \003(\0132,." +
+      "tensorflow.CallableOptions.FeedDevicesEn" +
+      "try\022D\n\rfetch_devices\030\007 \003(\0132-.tensorflow." +
+      "CallableOptions.FetchDevicesEntry\022\027\n\017fet" +
+      "ch_skip_sync\030\010 \001(\010\0322\n\020FeedDevicesEntry\022\013" +
+      "\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t:\0028\001\0323\n\021FetchD" +
+      "evicesEntry\022\013\n\003key\030\001 \001(\t\022\r\n\005value\030\002 \001(\t:" +
+      "\0028\001B\200\001\n\024org.tensorflow.protoB\014ConfigProt" +
+      "osP\001ZUgithub.com/tensorflow/tensorflow/t" +
+      "ensorflow/go/core/protobuf/for_core_prot" +
+      "os_go_proto\370\001\001b\006proto3"
     };
     descriptor = com.google.protobuf.Descriptors.FileDescriptor
       .internalBuildGeneratedFileFrom(descriptorData,
@@ -339,7 +342,7 @@ public static void registerAllExtensions(
     internal_static_tensorflow_ConfigProto_Experimental_fieldAccessorTable = new
       com.google.protobuf.GeneratedMessageV3.FieldAccessorTable(
         internal_static_tensorflow_ConfigProto_Experimental_descriptor,
-        new java.lang.String[] { "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession", "SessionMetadata", "OptimizeForStaticGraph", "EnableMlirBridge", "MlirBridgeRollout", "EnableMlirGraphOptimization", "DisableOutputPartitionGraphs", "XlaFusionAutotunerThresh", "UseTfrt", "DisableFunctionalOpsLowering", "XlaPreferSingleGraphCluster", "CoordinationConfig", "DisableOptimizeForStaticGraph", "DisableEagerExecutorStreamingEnqueue", });
+        new java.lang.String[] { "CollectiveGroupLeader", "ExecutorType", "RecvBufMaxChunk", "UseNumaAffinity", "CollectiveDeterministicSequentialExecution", "CollectiveNccl", "ShareSessionStateInClusterspecPropagation", "DisableThreadSpinning", "ShareClusterDevicesInSession", "SessionMetadata", "OptimizeForStaticGraph", "EnableMlirBridge", "MlirBridgeRollout", "EnableMlirGraphOptimization", "DisableOutputPartitionGraphs", "XlaFusionAutotunerThresh", "UseTfrt", "EnableMultiHost", "BackendServerPort", "TargetTpu", "TargetGpu", "StreamMergeThreshold", "DisableFunctionalOpsLowering", "XlaPreferSingleGraphCluster", "CoordinationConfig", "DisableOptimizeForStaticGraph", "DisableEagerExecutorStreamingEnqueue", });
     internal_static_tensorflow_RunOptions_descriptor =
       getDescriptor().getMessageTypes().get(6);
     internal_static_tensorflow_RunOptions_fieldAccessorTable = new
diff --git a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/DatasetOptions.java b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/DatasetOptions.java
index 4e244df2f12..c7d43294b14 100644
--- a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/DatasetOptions.java
+++ b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/DatasetOptions.java
@@ -2954,6 +2954,17 @@ public interface OptimizationOptionsOrBuilder extends
      */
     boolean getInjectPrefetch();
 
+    /**
+     * bool seq_interleave_prefetch = 21;
+     * @return Whether the seqInterleavePrefetch field is set.
+     */
+    boolean hasSeqInterleavePrefetch();
+    /**
+     * bool seq_interleave_prefetch = 21;
+     * @return The seqInterleavePrefetch.
+     */
+    boolean getSeqInterleavePrefetch();
+
     public org.tensorflow.proto.data.DatasetOptions.OptimizationOptions.OptionalApplyDefaultOptimizationsCase getOptionalApplyDefaultOptimizationsCase();
 
     public org.tensorflow.proto.data.DatasetOptions.OptimizationOptions.OptionalFilterFusionCase getOptionalFilterFusionCase();
@@ -2975,10 +2986,12 @@ public interface OptimizationOptionsOrBuilder extends
     public org.tensorflow.proto.data.DatasetOptions.OptimizationOptions.OptionalFilterParallelizationCase getOptionalFilterParallelizationCase();
 
     public org.tensorflow.proto.data.DatasetOptions.OptimizationOptions.OptionalInjectPrefetchCase getOptionalInjectPrefetchCase();
+
+    public org.tensorflow.proto.data.DatasetOptions.OptimizationOptions.OptionalSeqInterleavePrefetchCase getOptionalSeqInterleavePrefetchCase();
   }
   /**
    * 
-   * next: 21
+   * next: 22
    * 
* * Protobuf type {@code tensorflow.data.OptimizationOptions} @@ -3449,6 +3462,45 @@ public int getNumber() { optionalInjectPrefetchCase_); } + private int optionalSeqInterleavePrefetchCase_ = 0; + private java.lang.Object optionalSeqInterleavePrefetch_; + public enum OptionalSeqInterleavePrefetchCase + implements com.google.protobuf.Internal.EnumLite, + com.google.protobuf.AbstractMessage.InternalOneOfEnum { + SEQ_INTERLEAVE_PREFETCH(21), + OPTIONALSEQINTERLEAVEPREFETCH_NOT_SET(0); + private final int value; + private OptionalSeqInterleavePrefetchCase(int value) { + this.value = value; + } + /** + * @param value The number of the enum to look for. + * @return The enum associated with the given number. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static OptionalSeqInterleavePrefetchCase valueOf(int value) { + return forNumber(value); + } + + public static OptionalSeqInterleavePrefetchCase forNumber(int value) { + switch (value) { + case 21: return SEQ_INTERLEAVE_PREFETCH; + case 0: return OPTIONALSEQINTERLEAVEPREFETCH_NOT_SET; + default: return null; + } + } + public int getNumber() { + return this.value; + } + }; + + public OptionalSeqInterleavePrefetchCase + getOptionalSeqInterleavePrefetchCase() { + return OptionalSeqInterleavePrefetchCase.forNumber( + optionalSeqInterleavePrefetchCase_); + } + public static final int APPLY_DEFAULT_OPTIMIZATIONS_FIELD_NUMBER = 1; /** * bool apply_default_optimizations = 1; @@ -3680,6 +3732,27 @@ public boolean getInjectPrefetch() { return false; } + public static final int SEQ_INTERLEAVE_PREFETCH_FIELD_NUMBER = 21; + /** + * bool seq_interleave_prefetch = 21; + * @return Whether the seqInterleavePrefetch field is set. + */ + @java.lang.Override + public boolean hasSeqInterleavePrefetch() { + return optionalSeqInterleavePrefetchCase_ == 21; + } + /** + * bool seq_interleave_prefetch = 21; + * @return The seqInterleavePrefetch. 
+ */ + @java.lang.Override + public boolean getSeqInterleavePrefetch() { + if (optionalSeqInterleavePrefetchCase_ == 21) { + return (java.lang.Boolean) optionalSeqInterleavePrefetch_; + } + return false; + } + private byte memoizedIsInitialized = -1; @java.lang.Override public final boolean isInitialized() { @@ -3738,6 +3811,10 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) output.writeBool( 19, (boolean)((java.lang.Boolean) optionalInjectPrefetch_)); } + if (optionalSeqInterleavePrefetchCase_ == 21) { + output.writeBool( + 21, (boolean)((java.lang.Boolean) optionalSeqInterleavePrefetch_)); + } getUnknownFields().writeTo(output); } @@ -3802,6 +3879,11 @@ public int getSerializedSize() { .computeBoolSize( 19, (boolean)((java.lang.Boolean) optionalInjectPrefetch_)); } + if (optionalSeqInterleavePrefetchCase_ == 21) { + size += com.google.protobuf.CodedOutputStream + .computeBoolSize( + 21, (boolean)((java.lang.Boolean) optionalSeqInterleavePrefetch_)); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -3916,6 +3998,15 @@ public boolean equals(final java.lang.Object obj) { case 0: default: } + if (!getOptionalSeqInterleavePrefetchCase().equals(other.getOptionalSeqInterleavePrefetchCase())) return false; + switch (optionalSeqInterleavePrefetchCase_) { + case 21: + if (getSeqInterleavePrefetch() + != other.getSeqInterleavePrefetch()) return false; + break; + case 0: + default: + } if (!getUnknownFields().equals(other.getUnknownFields())) return false; return true; } @@ -4026,6 +4117,15 @@ public int hashCode() { case 0: default: } + switch (optionalSeqInterleavePrefetchCase_) { + case 21: + hash = (37 * hash) + SEQ_INTERLEAVE_PREFETCH_FIELD_NUMBER; + hash = (53 * hash) + com.google.protobuf.Internal.hashBoolean( + getSeqInterleavePrefetch()); + break; + case 0: + default: + } hash = (29 * hash) + getUnknownFields().hashCode(); memoizedHashCode = hash; return hash; @@ -4123,7 +4223,7 @@ protected Builder 
newBuilderForType( } /** *
-     * next: 21
+     * next: 22
      * 
* * Protobuf type {@code tensorflow.data.OptimizationOptions} @@ -4180,6 +4280,8 @@ public Builder clear() { optionalFilterParallelization_ = null; optionalInjectPrefetchCase_ = 0; optionalInjectPrefetch_ = null; + optionalSeqInterleavePrefetchCase_ = 0; + optionalSeqInterleavePrefetch_ = null; return this; } @@ -4239,6 +4341,9 @@ public org.tensorflow.proto.data.DatasetOptions.OptimizationOptions buildPartial if (optionalInjectPrefetchCase_ == 19) { result.optionalInjectPrefetch_ = optionalInjectPrefetch_; } + if (optionalSeqInterleavePrefetchCase_ == 21) { + result.optionalSeqInterleavePrefetch_ = optionalSeqInterleavePrefetch_; + } result.optionalApplyDefaultOptimizationsCase_ = optionalApplyDefaultOptimizationsCase_; result.optionalFilterFusionCase_ = optionalFilterFusionCase_; result.optionalMapAndBatchFusionCase_ = optionalMapAndBatchFusionCase_; @@ -4250,6 +4355,7 @@ public org.tensorflow.proto.data.DatasetOptions.OptimizationOptions buildPartial result.optionalShuffleAndRepeatFusionCase_ = optionalShuffleAndRepeatFusionCase_; result.optionalFilterParallelizationCase_ = optionalFilterParallelizationCase_; result.optionalInjectPrefetchCase_ = optionalInjectPrefetchCase_; + result.optionalSeqInterleavePrefetchCase_ = optionalSeqInterleavePrefetchCase_; onBuilt(); return result; } @@ -4397,6 +4503,15 @@ public Builder mergeFrom(org.tensorflow.proto.data.DatasetOptions.OptimizationOp break; } } + switch (other.getOptionalSeqInterleavePrefetchCase()) { + case SEQ_INTERLEAVE_PREFETCH: { + setSeqInterleavePrefetch(other.getSeqInterleavePrefetch()); + break; + } + case OPTIONALSEQINTERLEAVEPREFETCH_NOT_SET: { + break; + } + } this.mergeUnknownFields(other.getUnknownFields()); onChanged(); return this; @@ -4478,6 +4593,11 @@ public Builder mergeFrom( optionalInjectPrefetchCase_ = 19; break; } // case 152 + case 168: { + optionalSeqInterleavePrefetch_ = input.readBool(); + optionalSeqInterleavePrefetchCase_ = 21; + break; + } // case 168 default: { if 
(!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -4658,6 +4778,21 @@ public Builder clearOptionalInjectPrefetch() { return this; } + private int optionalSeqInterleavePrefetchCase_ = 0; + private java.lang.Object optionalSeqInterleavePrefetch_; + public OptionalSeqInterleavePrefetchCase + getOptionalSeqInterleavePrefetchCase() { + return OptionalSeqInterleavePrefetchCase.forNumber( + optionalSeqInterleavePrefetchCase_); + } + + public Builder clearOptionalSeqInterleavePrefetch() { + optionalSeqInterleavePrefetchCase_ = 0; + optionalSeqInterleavePrefetch_ = null; + onChanged(); + return this; + } + /** * bool apply_default_optimizations = 1; @@ -5109,6 +5244,47 @@ public Builder clearInjectPrefetch() { } return this; } + + /** + * bool seq_interleave_prefetch = 21; + * @return Whether the seqInterleavePrefetch field is set. + */ + public boolean hasSeqInterleavePrefetch() { + return optionalSeqInterleavePrefetchCase_ == 21; + } + /** + * bool seq_interleave_prefetch = 21; + * @return The seqInterleavePrefetch. + */ + public boolean getSeqInterleavePrefetch() { + if (optionalSeqInterleavePrefetchCase_ == 21) { + return (java.lang.Boolean) optionalSeqInterleavePrefetch_; + } + return false; + } + /** + * bool seq_interleave_prefetch = 21; + * @param value The seqInterleavePrefetch to set. + * @return This builder for chaining. + */ + public Builder setSeqInterleavePrefetch(boolean value) { + optionalSeqInterleavePrefetchCase_ = 21; + optionalSeqInterleavePrefetch_ = value; + onChanged(); + return this; + } + /** + * bool seq_interleave_prefetch = 21; + * @return This builder for chaining. 
+ */ + public Builder clearSeqInterleavePrefetch() { + if (optionalSeqInterleavePrefetchCase_ == 21) { + optionalSeqInterleavePrefetchCase_ = 0; + optionalSeqInterleavePrefetch_ = null; + onChanged(); + } + return this; + } @java.lang.Override public final Builder setUnknownFields( final com.google.protobuf.UnknownFieldSet unknownFields) { @@ -5935,6 +6111,64 @@ public interface OptionsOrBuilder extends // @@protoc_insertion_point(interface_extends:tensorflow.data.Options) com.google.protobuf.MessageOrBuilder { + /** + * string dataset_name = 10; + * @return Whether the datasetName field is set. + */ + boolean hasDatasetName(); + /** + * string dataset_name = 10; + * @return The datasetName. + */ + java.lang.String getDatasetName(); + /** + * string dataset_name = 10; + * @return The bytes for datasetName. + */ + com.google.protobuf.ByteString + getDatasetNameBytes(); + + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @return A list containing the frameworkType. + */ + java.util.List + getFrameworkTypeList(); + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @return The count of frameworkType. + */ + int getFrameworkTypeCount(); + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @param index The index of the element to return. + * @return The frameworkType at the given index. + */ + java.lang.String getFrameworkType(int index); + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @param index The index of the value to return. + * @return The bytes of the frameworkType at the given index. + */ + com.google.protobuf.ByteString + getFrameworkTypeBytes(int index); + /** * bool deterministic = 1; * @return Whether the deterministic field is set. @@ -6103,6 +6337,8 @@ public interface OptionsOrBuilder extends */ boolean getWarmStart(); + public org.tensorflow.proto.data.DatasetOptions.Options.OptionalDatasetNameCase getOptionalDatasetNameCase(); + public org.tensorflow.proto.data.DatasetOptions.Options.OptionalDeterministicCase getOptionalDeterministicCase(); public org.tensorflow.proto.data.DatasetOptions.Options.OptionalSlackCase getOptionalSlackCase(); @@ -6117,7 +6353,7 @@ public interface OptionsOrBuilder extends *
    * Message stored with Dataset objects to control how datasets are processed and
    * optimized.
-   * next: 10
+   * next: 12
    * 
* * Protobuf type {@code tensorflow.data.Options} @@ -6132,6 +6368,7 @@ private Options(com.google.protobuf.GeneratedMessageV3.Builder builder) { super(builder); } private Options() { + frameworkType_ = com.google.protobuf.LazyStringArrayList.EMPTY; } @java.lang.Override @@ -6159,6 +6396,45 @@ protected java.lang.Object newInstance( org.tensorflow.proto.data.DatasetOptions.Options.class, org.tensorflow.proto.data.DatasetOptions.Options.Builder.class); } + private int optionalDatasetNameCase_ = 0; + private java.lang.Object optionalDatasetName_; + public enum OptionalDatasetNameCase + implements com.google.protobuf.Internal.EnumLite, + com.google.protobuf.AbstractMessage.InternalOneOfEnum { + DATASET_NAME(10), + OPTIONALDATASETNAME_NOT_SET(0); + private final int value; + private OptionalDatasetNameCase(int value) { + this.value = value; + } + /** + * @param value The number of the enum to look for. + * @return The enum associated with the given number. + * @deprecated Use {@link #forNumber(int)} instead. + */ + @java.lang.Deprecated + public static OptionalDatasetNameCase valueOf(int value) { + return forNumber(value); + } + + public static OptionalDatasetNameCase forNumber(int value) { + switch (value) { + case 10: return DATASET_NAME; + case 0: return OPTIONALDATASETNAME_NOT_SET; + default: return null; + } + } + public int getNumber() { + return this.value; + } + }; + + public OptionalDatasetNameCase + getOptionalDatasetNameCase() { + return OptionalDatasetNameCase.forNumber( + optionalDatasetNameCase_); + } + private int optionalDeterministicCase_ = 0; private java.lang.Object optionalDeterministic_; public enum OptionalDeterministicCase @@ -6354,6 +6630,109 @@ public int getNumber() { optionalWarmStartCase_); } + public static final int DATASET_NAME_FIELD_NUMBER = 10; + /** + * string dataset_name = 10; + * @return Whether the datasetName field is set. 
+ */ + public boolean hasDatasetName() { + return optionalDatasetNameCase_ == 10; + } + /** + * string dataset_name = 10; + * @return The datasetName. + */ + public java.lang.String getDatasetName() { + java.lang.Object ref = ""; + if (optionalDatasetNameCase_ == 10) { + ref = optionalDatasetName_; + } + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (optionalDatasetNameCase_ == 10) { + optionalDatasetName_ = s; + } + return s; + } + } + /** + * string dataset_name = 10; + * @return The bytes for datasetName. + */ + public com.google.protobuf.ByteString + getDatasetNameBytes() { + java.lang.Object ref = ""; + if (optionalDatasetNameCase_ == 10) { + ref = optionalDatasetName_; + } + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + if (optionalDatasetNameCase_ == 10) { + optionalDatasetName_ = b; + } + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + + public static final int FRAMEWORK_TYPE_FIELD_NUMBER = 11; + private com.google.protobuf.LazyStringList frameworkType_; + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @return A list containing the frameworkType. + */ + public com.google.protobuf.ProtocolStringList + getFrameworkTypeList() { + return frameworkType_; + } + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @return The count of frameworkType. + */ + public int getFrameworkTypeCount() { + return frameworkType_.size(); + } + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @param index The index of the element to return. + * @return The frameworkType at the given index. + */ + public java.lang.String getFrameworkType(int index) { + return frameworkType_.get(index); + } + /** + *
+     * List of frameworks used to generate this dataset.
+     * 
+ * + * repeated string framework_type = 11; + * @param index The index of the value to return. + * @return The bytes of the frameworkType at the given index. + */ + public com.google.protobuf.ByteString + getFrameworkTypeBytes(int index) { + return frameworkType_.getByteString(index); + } + public static final int DETERMINISTIC_FIELD_NUMBER = 1; /** * bool deterministic = 1; @@ -6667,6 +7046,12 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) output.writeBool( 9, (boolean)((java.lang.Boolean) optionalWarmStart_)); } + if (optionalDatasetNameCase_ == 10) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 10, optionalDatasetName_); + } + for (int i = 0; i < frameworkType_.size(); i++) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 11, frameworkType_.getRaw(i)); + } getUnknownFields().writeTo(output); } @@ -6716,6 +7101,17 @@ public int getSerializedSize() { .computeBoolSize( 9, (boolean)((java.lang.Boolean) optionalWarmStart_)); } + if (optionalDatasetNameCase_ == 10) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(10, optionalDatasetName_); + } + { + int dataSize = 0; + for (int i = 0; i < frameworkType_.size(); i++) { + dataSize += computeStringSizeNoTag(frameworkType_.getRaw(i)); + } + size += dataSize; + size += 1 * getFrameworkTypeList().size(); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -6731,6 +7127,8 @@ public boolean equals(final java.lang.Object obj) { } org.tensorflow.proto.data.DatasetOptions.Options other = (org.tensorflow.proto.data.DatasetOptions.Options) obj; + if (!getFrameworkTypeList() + .equals(other.getFrameworkTypeList())) return false; if (hasAutotuneOptions() != other.hasAutotuneOptions()) return false; if (hasAutotuneOptions()) { if (!getAutotuneOptions() @@ -6751,6 +7149,15 @@ public boolean equals(final java.lang.Object obj) { if (!getThreadingOptions() .equals(other.getThreadingOptions())) return false; } + if 
(!getOptionalDatasetNameCase().equals(other.getOptionalDatasetNameCase())) return false; + switch (optionalDatasetNameCase_) { + case 10: + if (!getDatasetName() + .equals(other.getDatasetName())) return false; + break; + case 0: + default: + } if (!getOptionalDeterministicCase().equals(other.getOptionalDeterministicCase())) return false; switch (optionalDeterministicCase_) { case 1: @@ -6807,6 +7214,10 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + if (getFrameworkTypeCount() > 0) { + hash = (37 * hash) + FRAMEWORK_TYPE_FIELD_NUMBER; + hash = (53 * hash) + getFrameworkTypeList().hashCode(); + } if (hasAutotuneOptions()) { hash = (37 * hash) + AUTOTUNE_OPTIONS_FIELD_NUMBER; hash = (53 * hash) + getAutotuneOptions().hashCode(); @@ -6823,6 +7234,14 @@ public int hashCode() { hash = (37 * hash) + THREADING_OPTIONS_FIELD_NUMBER; hash = (53 * hash) + getThreadingOptions().hashCode(); } + switch (optionalDatasetNameCase_) { + case 10: + hash = (37 * hash) + DATASET_NAME_FIELD_NUMBER; + hash = (53 * hash) + getDatasetName().hashCode(); + break; + case 0: + default: + } switch (optionalDeterministicCase_) { case 1: hash = (37 * hash) + DETERMINISTIC_FIELD_NUMBER; @@ -6966,7 +7385,7 @@ protected Builder newBuilderForType( *
      * Message stored with Dataset objects to control how datasets are processed and
      * optimized.
-     * next: 10
+     * next: 12
      * 
* * Protobuf type {@code tensorflow.data.Options} @@ -7001,6 +7420,8 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + frameworkType_ = com.google.protobuf.LazyStringArrayList.EMPTY; + bitField0_ = (bitField0_ & ~0x00000001); if (autotuneOptionsBuilder_ == null) { autotuneOptions_ = null; } else { @@ -7025,6 +7446,8 @@ public Builder clear() { threadingOptions_ = null; threadingOptionsBuilder_ = null; } + optionalDatasetNameCase_ = 0; + optionalDatasetName_ = null; optionalDeterministicCase_ = 0; optionalDeterministic_ = null; optionalSlackCase_ = 0; @@ -7061,6 +7484,15 @@ public org.tensorflow.proto.data.DatasetOptions.Options build() { @java.lang.Override public org.tensorflow.proto.data.DatasetOptions.Options buildPartial() { org.tensorflow.proto.data.DatasetOptions.Options result = new org.tensorflow.proto.data.DatasetOptions.Options(this); + int from_bitField0_ = bitField0_; + if (optionalDatasetNameCase_ == 10) { + result.optionalDatasetName_ = optionalDatasetName_; + } + if (((bitField0_ & 0x00000001) != 0)) { + frameworkType_ = frameworkType_.getUnmodifiableView(); + bitField0_ = (bitField0_ & ~0x00000001); + } + result.frameworkType_ = frameworkType_; if (optionalDeterministicCase_ == 1) { result.optionalDeterministic_ = optionalDeterministic_; } @@ -7096,6 +7528,7 @@ public org.tensorflow.proto.data.DatasetOptions.Options buildPartial() { if (optionalWarmStartCase_ == 9) { result.optionalWarmStart_ = optionalWarmStart_; } + result.optionalDatasetNameCase_ = optionalDatasetNameCase_; result.optionalDeterministicCase_ = optionalDeterministicCase_; result.optionalSlackCase_ = optionalSlackCase_; result.optionalExternalStatePolicyCase_ = optionalExternalStatePolicyCase_; @@ -7149,6 +7582,16 @@ public Builder mergeFrom(com.google.protobuf.Message other) { public Builder mergeFrom(org.tensorflow.proto.data.DatasetOptions.Options other) { if (other == org.tensorflow.proto.data.DatasetOptions.Options.getDefaultInstance()) return 
this; + if (!other.frameworkType_.isEmpty()) { + if (frameworkType_.isEmpty()) { + frameworkType_ = other.frameworkType_; + bitField0_ = (bitField0_ & ~0x00000001); + } else { + ensureFrameworkTypeIsMutable(); + frameworkType_.addAll(other.frameworkType_); + } + onChanged(); + } if (other.hasAutotuneOptions()) { mergeAutotuneOptions(other.getAutotuneOptions()); } @@ -7161,6 +7604,17 @@ public Builder mergeFrom(org.tensorflow.proto.data.DatasetOptions.Options other) if (other.hasThreadingOptions()) { mergeThreadingOptions(other.getThreadingOptions()); } + switch (other.getOptionalDatasetNameCase()) { + case DATASET_NAME: { + optionalDatasetNameCase_ = 10; + optionalDatasetName_ = other.optionalDatasetName_; + onChanged(); + break; + } + case OPTIONALDATASETNAME_NOT_SET: { + break; + } + } switch (other.getOptionalDeterministicCase()) { case DETERMINISTIC: { setDeterministic(other.getDeterministic()); @@ -7286,6 +7740,18 @@ public Builder mergeFrom( optionalWarmStartCase_ = 9; break; } // case 72 + case 82: { + java.lang.String s = input.readStringRequireUtf8(); + optionalDatasetNameCase_ = 10; + optionalDatasetName_ = s; + break; + } // case 82 + case 90: { + java.lang.String s = input.readStringRequireUtf8(); + ensureFrameworkTypeIsMutable(); + frameworkType_.add(s); + break; + } // case 90 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -7301,6 +7767,21 @@ public Builder mergeFrom( } // finally return this; } + private int optionalDatasetNameCase_ = 0; + private java.lang.Object optionalDatasetName_; + public OptionalDatasetNameCase + getOptionalDatasetNameCase() { + return OptionalDatasetNameCase.forNumber( + optionalDatasetNameCase_); + } + + public Builder clearOptionalDatasetName() { + optionalDatasetNameCase_ = 0; + optionalDatasetName_ = null; + onChanged(); + return this; + } + private int optionalDeterministicCase_ = 0; private java.lang.Object optionalDeterministic_; public 
OptionalDeterministicCase @@ -7376,6 +7857,250 @@ public Builder clearOptionalWarmStart() { return this; } + private int bitField0_; + + /** + * string dataset_name = 10; + * @return Whether the datasetName field is set. + */ + @java.lang.Override + public boolean hasDatasetName() { + return optionalDatasetNameCase_ == 10; + } + /** + * string dataset_name = 10; + * @return The datasetName. + */ + @java.lang.Override + public java.lang.String getDatasetName() { + java.lang.Object ref = ""; + if (optionalDatasetNameCase_ == 10) { + ref = optionalDatasetName_; + } + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + if (optionalDatasetNameCase_ == 10) { + optionalDatasetName_ = s; + } + return s; + } else { + return (java.lang.String) ref; + } + } + /** + * string dataset_name = 10; + * @return The bytes for datasetName. + */ + @java.lang.Override + public com.google.protobuf.ByteString + getDatasetNameBytes() { + java.lang.Object ref = ""; + if (optionalDatasetNameCase_ == 10) { + ref = optionalDatasetName_; + } + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + if (optionalDatasetNameCase_ == 10) { + optionalDatasetName_ = b; + } + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + * string dataset_name = 10; + * @param value The datasetName to set. + * @return This builder for chaining. + */ + public Builder setDatasetName( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + optionalDatasetNameCase_ = 10; + optionalDatasetName_ = value; + onChanged(); + return this; + } + /** + * string dataset_name = 10; + * @return This builder for chaining. 
+ */ + public Builder clearDatasetName() { + if (optionalDatasetNameCase_ == 10) { + optionalDatasetNameCase_ = 0; + optionalDatasetName_ = null; + onChanged(); + } + return this; + } + /** + * string dataset_name = 10; + * @param value The bytes for datasetName to set. + * @return This builder for chaining. + */ + public Builder setDatasetNameBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + optionalDatasetNameCase_ = 10; + optionalDatasetName_ = value; + onChanged(); + return this; + } + + private com.google.protobuf.LazyStringList frameworkType_ = com.google.protobuf.LazyStringArrayList.EMPTY; + private void ensureFrameworkTypeIsMutable() { + if (!((bitField0_ & 0x00000001) != 0)) { + frameworkType_ = new com.google.protobuf.LazyStringArrayList(frameworkType_); + bitField0_ |= 0x00000001; + } + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @return A list containing the frameworkType. + */ + public com.google.protobuf.ProtocolStringList + getFrameworkTypeList() { + return frameworkType_.getUnmodifiableView(); + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @return The count of frameworkType. + */ + public int getFrameworkTypeCount() { + return frameworkType_.size(); + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @param index The index of the element to return. + * @return The frameworkType at the given index. + */ + public java.lang.String getFrameworkType(int index) { + return frameworkType_.get(index); + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @param index The index of the value to return. + * @return The bytes of the frameworkType at the given index. + */ + public com.google.protobuf.ByteString + getFrameworkTypeBytes(int index) { + return frameworkType_.getByteString(index); + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @param index The index to set the value at. + * @param value The frameworkType to set. + * @return This builder for chaining. + */ + public Builder setFrameworkType( + int index, java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + ensureFrameworkTypeIsMutable(); + frameworkType_.set(index, value); + onChanged(); + return this; + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @param value The frameworkType to add. + * @return This builder for chaining. + */ + public Builder addFrameworkType( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + ensureFrameworkTypeIsMutable(); + frameworkType_.add(value); + onChanged(); + return this; + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @param values The frameworkType to add. + * @return This builder for chaining. + */ + public Builder addAllFrameworkType( + java.lang.Iterable values) { + ensureFrameworkTypeIsMutable(); + com.google.protobuf.AbstractMessageLite.Builder.addAll( + values, frameworkType_); + onChanged(); + return this; + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @return This builder for chaining. + */ + public Builder clearFrameworkType() { + frameworkType_ = com.google.protobuf.LazyStringArrayList.EMPTY; + bitField0_ = (bitField0_ & ~0x00000001); + onChanged(); + return this; + } + /** + *
+       * List of frameworks used to generate this dataset.
+       * 
+ * + * repeated string framework_type = 11; + * @param value The bytes of the frameworkType to add. + * @return This builder for chaining. + */ + public Builder addFrameworkTypeBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + ensureFrameworkTypeIsMutable(); + frameworkType_.add(value); + onChanged(); + return this; + } /** * bool deterministic = 1; @@ -8351,7 +9076,7 @@ public org.tensorflow.proto.data.DatasetOptions.Options getDefaultInstanceForTyp "ODERATE\020\002\"\177\n\021DistributeOptions\022;\n\021auto_s" + "hard_policy\030\001 \001(\0162 .tensorflow.data.Auto" + "ShardPolicy\022\025\n\013num_devices\030\002 \001(\005H\000B\026\n\024op" + - "tional_num_devices\"\362\005\n\023OptimizationOptio" + + "tional_num_devices\"\271\006\n\023OptimizationOptio" + "ns\022%\n\033apply_default_optimizations\030\001 \001(\010H" + "\000\022\027\n\rfilter_fusion\030\006 \001(\010H\001\022\036\n\024map_and_ba" + "tch_fusion\030\t \001(\010H\002\022\037\n\025map_and_filter_fus" + @@ -8360,42 +9085,46 @@ public org.tensorflow.proto.data.DatasetOptions.Options getDefaultInstanceForTyp "tion\030\016 \001(\010H\006\022\030\n\016parallel_batch\030\017 \001(\010H\007\022#" + "\n\031shuffle_and_repeat_fusion\030\021 \001(\010H\010\022 \n\026f" + "ilter_parallelization\030\022 \001(\010H\t\022\031\n\017inject_" + - "prefetch\030\023 \001(\010H\nB&\n$optional_apply_defau" + - "lt_optimizationsB\030\n\026optional_filter_fusi" + - "onB\037\n\035optional_map_and_batch_fusionB \n\036o" + - "ptional_map_and_filter_fusionB\025\n\023optiona" + - "l_map_fusionB\036\n\034optional_map_paralleliza" + - "tionB\033\n\031optional_noop_eliminationB\031\n\027opt" + - "ional_parallel_batchB$\n\"optional_shuffle" + - "_and_repeat_fusionB!\n\037optional_filter_pa" + - "rallelizationB\032\n\030optional_inject_prefetc" + - 
"hJ\004\010\002\020\003J\004\010\003\020\004J\004\010\004\020\005J\004\010\005\020\006J\004\010\007\020\010J\004\010\010\020\tJ\004\010" + - "\r\020\016J\004\010\020\020\021J\004\010\024\020\025\"\242\001\n\020ThreadingOptions\022\"\n\030" + - "max_intra_op_parallelism\030\001 \001(\005H\000\022!\n\027priv" + - "ate_threadpool_size\030\002 \001(\005H\001B#\n!optional_" + - "max_intra_op_parallelismB\"\n optional_pri" + - "vate_threadpool_size\"\262\004\n\007Options\022\027\n\rdete" + - "rministic\030\001 \001(\010H\000\022:\n\020autotune_options\030\007 " + - "\001(\0132 .tensorflow.data.AutotuneOptions\022>\n" + - "\022distribute_options\030\002 \001(\0132\".tensorflow.d" + - "ata.DistributeOptions\022B\n\024optimization_op" + - "tions\030\003 \001(\0132$.tensorflow.data.Optimizati" + - "onOptions\022\017\n\005slack\030\004 \001(\010H\001\022<\n\021threading_" + - "options\030\005 \001(\0132!.tensorflow.data.Threadin" + - "gOptions\022E\n\025external_state_policy\030\006 \001(\0162" + - "$.tensorflow.data.ExternalStatePolicyH\002\022" + - "\035\n\023symbolic_checkpoint\030\010 \001(\010H\003\022\024\n\nwarm_s" + - "tart\030\t \001(\010H\004B\030\n\026optional_deterministicB\020" + - "\n\016optional_slackB \n\036optional_external_st" + - "ate_policyB\036\n\034optional_symbolic_checkpoi" + - "ntB\025\n\023optional_warm_start*K\n\017AutoShardPo" + - "licy\022\010\n\004AUTO\020\000\022\010\n\004FILE\020\001\022\010\n\004DATA\020\002\022\010\n\004HI" + - "NT\020\003\022\020\n\003OFF\020\377\377\377\377\377\377\377\377\377\001*J\n\023ExternalStateP" + - "olicy\022\017\n\013POLICY_WARN\020\000\022\021\n\rPOLICY_IGNORE\020" + - "\001\022\017\n\013POLICY_FAIL\020\002Bs\n\031org.tensorflow.pro" + - "to.dataZVgithub.com/tensorflow/tensorflo" + - "w/tensorflow/go/core/framework/dataset_o" + - "ptions_go_protob\006proto3" + "prefetch\030\023 \001(\010H\n\022!\n\027seq_interleave_prefe" + + "tch\030\025 \001(\010H\013B&\n$optional_apply_default_op" + + 
"timizationsB\030\n\026optional_filter_fusionB\037\n" + + "\035optional_map_and_batch_fusionB \n\036option" + + "al_map_and_filter_fusionB\025\n\023optional_map" + + "_fusionB\036\n\034optional_map_parallelizationB" + + "\033\n\031optional_noop_eliminationB\031\n\027optional" + + "_parallel_batchB$\n\"optional_shuffle_and_" + + "repeat_fusionB!\n\037optional_filter_paralle" + + "lizationB\032\n\030optional_inject_prefetchB\"\n " + + "optional_seq_interleave_prefetchJ\004\010\002\020\003J\004" + + "\010\003\020\004J\004\010\004\020\005J\004\010\005\020\006J\004\010\007\020\010J\004\010\010\020\tJ\004\010\r\020\016J\004\010\020\020\021" + + "J\004\010\024\020\025\"\242\001\n\020ThreadingOptions\022\"\n\030max_intra" + + "_op_parallelism\030\001 \001(\005H\000\022!\n\027private_threa" + + "dpool_size\030\002 \001(\005H\001B#\n!optional_max_intra" + + "_op_parallelismB\"\n optional_private_thre" + + "adpool_size\"\373\004\n\007Options\022\026\n\014dataset_name\030" + + "\n \001(\tH\000\022\026\n\016framework_type\030\013 \003(\t\022\027\n\rdeter" + + "ministic\030\001 \001(\010H\001\022:\n\020autotune_options\030\007 \001" + + "(\0132 .tensorflow.data.AutotuneOptions\022>\n\022" + + "distribute_options\030\002 \001(\0132\".tensorflow.da" + + "ta.DistributeOptions\022B\n\024optimization_opt" + + "ions\030\003 \001(\0132$.tensorflow.data.Optimizatio" + + "nOptions\022\017\n\005slack\030\004 \001(\010H\002\022<\n\021threading_o" + + "ptions\030\005 \001(\0132!.tensorflow.data.Threading" + + "Options\022E\n\025external_state_policy\030\006 \001(\0162$" + + ".tensorflow.data.ExternalStatePolicyH\003\022\035" + + "\n\023symbolic_checkpoint\030\010 \001(\010H\004\022\024\n\nwarm_st" + + "art\030\t \001(\010H\005B\027\n\025optional_dataset_nameB\030\n\026" + + "optional_deterministicB\020\n\016optional_slack" + + "B \n\036optional_external_state_policyB\036\n\034op" + + "tional_symbolic_checkpointB\025\n\023optional_w" + + 
"arm_start*K\n\017AutoShardPolicy\022\010\n\004AUTO\020\000\022\010" + + "\n\004FILE\020\001\022\010\n\004DATA\020\002\022\010\n\004HINT\020\003\022\020\n\003OFF\020\377\377\377\377" + + "\377\377\377\377\377\001*J\n\023ExternalStatePolicy\022\017\n\013POLICY_" + + "WARN\020\000\022\021\n\rPOLICY_IGNORE\020\001\022\017\n\013POLICY_FAIL" + + "\020\002Bs\n\031org.tensorflow.proto.dataZVgithub." + + "com/tensorflow/tensorflow/tensorflow/go/" + + "core/framework/dataset_options_go_protob" + + "\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -8425,7 +9154,7 @@ public org.tensorflow.proto.data.DatasetOptions.Options getDefaultInstanceForTyp internal_static_tensorflow_data_OptimizationOptions_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_tensorflow_data_OptimizationOptions_descriptor, - new java.lang.String[] { "ApplyDefaultOptimizations", "FilterFusion", "MapAndBatchFusion", "MapAndFilterFusion", "MapFusion", "MapParallelization", "NoopElimination", "ParallelBatch", "ShuffleAndRepeatFusion", "FilterParallelization", "InjectPrefetch", "OptionalApplyDefaultOptimizations", "OptionalFilterFusion", "OptionalMapAndBatchFusion", "OptionalMapAndFilterFusion", "OptionalMapFusion", "OptionalMapParallelization", "OptionalNoopElimination", "OptionalParallelBatch", "OptionalShuffleAndRepeatFusion", "OptionalFilterParallelization", "OptionalInjectPrefetch", }); + new java.lang.String[] { "ApplyDefaultOptimizations", "FilterFusion", "MapAndBatchFusion", "MapAndFilterFusion", "MapFusion", "MapParallelization", "NoopElimination", "ParallelBatch", "ShuffleAndRepeatFusion", "FilterParallelization", "InjectPrefetch", "SeqInterleavePrefetch", "OptionalApplyDefaultOptimizations", "OptionalFilterFusion", "OptionalMapAndBatchFusion", "OptionalMapAndFilterFusion", "OptionalMapFusion", "OptionalMapParallelization", "OptionalNoopElimination", "OptionalParallelBatch", 
"OptionalShuffleAndRepeatFusion", "OptionalFilterParallelization", "OptionalInjectPrefetch", "OptionalSeqInterleavePrefetch", }); internal_static_tensorflow_data_ThreadingOptions_descriptor = getDescriptor().getMessageTypes().get(4); internal_static_tensorflow_data_ThreadingOptions_fieldAccessorTable = new @@ -8437,7 +9166,7 @@ public org.tensorflow.proto.data.DatasetOptions.Options getDefaultInstanceForTyp internal_static_tensorflow_data_Options_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_tensorflow_data_Options_descriptor, - new java.lang.String[] { "Deterministic", "AutotuneOptions", "DistributeOptions", "OptimizationOptions", "Slack", "ThreadingOptions", "ExternalStatePolicy", "SymbolicCheckpoint", "WarmStart", "OptionalDeterministic", "OptionalSlack", "OptionalExternalStatePolicy", "OptionalSymbolicCheckpoint", "OptionalWarmStart", }); + new java.lang.String[] { "DatasetName", "FrameworkType", "Deterministic", "AutotuneOptions", "DistributeOptions", "OptimizationOptions", "Slack", "ThreadingOptions", "ExternalStatePolicy", "SymbolicCheckpoint", "WarmStart", "OptionalDatasetName", "OptionalDeterministic", "OptionalSlack", "OptionalExternalStatePolicy", "OptionalSymbolicCheckpoint", "OptionalWarmStart", }); org.tensorflow.proto.data.model.Model.getDescriptor(); } diff --git a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/model/Model.java b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/model/Model.java index f03d10c2626..e89b1024d4c 100644 --- a/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/model/Model.java +++ b/tensorflow-core/tensorflow-core-native/src/gen/java/org/tensorflow/proto/data/model/Model.java @@ -314,6 +314,26 @@ public interface ModelProtoOrBuilder extends // @@protoc_insertion_point(interface_extends:tensorflow.data.model.ModelProto) com.google.protobuf.MessageOrBuilder { + /** + *
+     * User-defined name for the dataset. Empty if no name was set.
+     * 
+ * + * string dataset_name = 7; + * @return The datasetName. + */ + java.lang.String getDatasetName(); + /** + *
+     * User-defined name for the dataset. Empty if no name was set.
+     * 
+ * + * string dataset_name = 7; + * @return The bytes for datasetName. + */ + com.google.protobuf.ByteString + getDatasetNameBytes(); + /** *
      * Map of node IDs to nodes of this model.
@@ -440,6 +460,7 @@ private ModelProto(com.google.protobuf.GeneratedMessageV3.Builder builder) {
       super(builder);
     }
     private ModelProto() {
+      datasetName_ = "";
       gapTimes_ = emptyLongList();
     }
 
@@ -5007,6 +5028,52 @@ public org.tensorflow.proto.data.model.Model.ModelProto.OptimizationParams getDe
 
     }
 
+    public static final int DATASET_NAME_FIELD_NUMBER = 7;
+    private volatile java.lang.Object datasetName_;
+    /**
+     * 
+     * User-defined name for the dataset. Empty if no name was set.
+     * 
+ * + * string dataset_name = 7; + * @return The datasetName. + */ + @java.lang.Override + public java.lang.String getDatasetName() { + java.lang.Object ref = datasetName_; + if (ref instanceof java.lang.String) { + return (java.lang.String) ref; + } else { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + datasetName_ = s; + return s; + } + } + /** + *
+     * User-defined name for the dataset. Empty if no name was set.
+     * 
+ * + * string dataset_name = 7; + * @return The bytes for datasetName. + */ + @java.lang.Override + public com.google.protobuf.ByteString + getDatasetNameBytes() { + java.lang.Object ref = datasetName_; + if (ref instanceof java.lang.String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + datasetName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + public static final int NODES_FIELD_NUMBER = 1; private static final class NodesDefaultEntryHolder { static final com.google.protobuf.MapEntry< @@ -5225,6 +5292,9 @@ public void writeTo(com.google.protobuf.CodedOutputStream output) for (int i = 0; i < gapTimes_.size(); i++) { output.writeUInt64NoTag(gapTimes_.getLong(i)); } + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(datasetName_)) { + com.google.protobuf.GeneratedMessageV3.writeString(output, 7, datasetName_); + } getUnknownFields().writeTo(output); } @@ -5270,6 +5340,9 @@ public int getSerializedSize() { } gapTimesMemoizedSerializedSize = dataSize; } + if (!com.google.protobuf.GeneratedMessageV3.isStringEmpty(datasetName_)) { + size += com.google.protobuf.GeneratedMessageV3.computeStringSize(7, datasetName_); + } size += getUnknownFields().getSerializedSize(); memoizedSize = size; return size; @@ -5285,6 +5358,8 @@ public boolean equals(final java.lang.Object obj) { } org.tensorflow.proto.data.model.Model.ModelProto other = (org.tensorflow.proto.data.model.Model.ModelProto) obj; + if (!getDatasetName() + .equals(other.getDatasetName())) return false; if (!internalGetNodes().equals( other.internalGetNodes())) return false; if (getOutput() @@ -5309,6 +5384,8 @@ public int hashCode() { } int hash = 41; hash = (19 * hash) + getDescriptor().hashCode(); + hash = (37 * hash) + DATASET_NAME_FIELD_NUMBER; + hash = (53 * hash) + getDatasetName().hashCode(); if (!internalGetNodes().getMap().isEmpty()) { hash = (37 * hash) + NODES_FIELD_NUMBER; hash = (53 * 
hash) + internalGetNodes().hashCode(); @@ -5482,6 +5559,8 @@ private Builder( @java.lang.Override public Builder clear() { super.clear(); + datasetName_ = ""; + internalGetMutableNodes().clear(); output_ = 0L; @@ -5522,6 +5601,7 @@ public org.tensorflow.proto.data.model.Model.ModelProto build() { public org.tensorflow.proto.data.model.Model.ModelProto buildPartial() { org.tensorflow.proto.data.model.Model.ModelProto result = new org.tensorflow.proto.data.model.Model.ModelProto(this); int from_bitField0_ = bitField0_; + result.datasetName_ = datasetName_; result.nodes_ = internalGetNodes(); result.nodes_.makeImmutable(); result.output_ = output_; @@ -5584,6 +5664,10 @@ public Builder mergeFrom(com.google.protobuf.Message other) { public Builder mergeFrom(org.tensorflow.proto.data.model.Model.ModelProto other) { if (other == org.tensorflow.proto.data.model.Model.ModelProto.getDefaultInstance()) return this; + if (!other.getDatasetName().isEmpty()) { + datasetName_ = other.datasetName_; + onChanged(); + } internalGetMutableNodes().mergeFrom( other.internalGetNodes()); if (other.getOutput() != 0L) { @@ -5672,6 +5756,11 @@ public Builder mergeFrom( input.popLimit(limit); break; } // case 50 + case 58: { + datasetName_ = input.readStringRequireUtf8(); + + break; + } // case 58 default: { if (!super.parseUnknownField(input, extensionRegistry, tag)) { done = true; // was an endgroup tag @@ -5689,6 +5778,102 @@ public Builder mergeFrom( } private int bitField0_; + private java.lang.Object datasetName_ = ""; + /** + *
+       * User-defined name for the dataset. Empty if no name was set.
+       * 
+ * + * string dataset_name = 7; + * @return The datasetName. + */ + public java.lang.String getDatasetName() { + java.lang.Object ref = datasetName_; + if (!(ref instanceof java.lang.String)) { + com.google.protobuf.ByteString bs = + (com.google.protobuf.ByteString) ref; + java.lang.String s = bs.toStringUtf8(); + datasetName_ = s; + return s; + } else { + return (java.lang.String) ref; + } + } + /** + *
+       * User-defined name for the dataset. Empty if no name was set.
+       * 
+ * + * string dataset_name = 7; + * @return The bytes for datasetName. + */ + public com.google.protobuf.ByteString + getDatasetNameBytes() { + java.lang.Object ref = datasetName_; + if (ref instanceof String) { + com.google.protobuf.ByteString b = + com.google.protobuf.ByteString.copyFromUtf8( + (java.lang.String) ref); + datasetName_ = b; + return b; + } else { + return (com.google.protobuf.ByteString) ref; + } + } + /** + *
+       * User-defined name for the dataset. Empty if no name was set.
+       * 
+ * + * string dataset_name = 7; + * @param value The datasetName to set. + * @return This builder for chaining. + */ + public Builder setDatasetName( + java.lang.String value) { + if (value == null) { + throw new NullPointerException(); + } + + datasetName_ = value; + onChanged(); + return this; + } + /** + *
+       * User-defined name for the dataset. Empty if no name was set.
+       * 
+ * + * string dataset_name = 7; + * @return This builder for chaining. + */ + public Builder clearDatasetName() { + + datasetName_ = getDefaultInstance().getDatasetName(); + onChanged(); + return this; + } + /** + *
+       * User-defined name for the dataset. Empty if no name was set.
+       * 
+ * + * string dataset_name = 7; + * @param value The bytes for datasetName to set. + * @return This builder for chaining. + */ + public Builder setDatasetNameBytes( + com.google.protobuf.ByteString value) { + if (value == null) { + throw new NullPointerException(); + } + checkByteStringIsUtf8(value); + + datasetName_ = value; + onChanged(); + return this; + } + private com.google.protobuf.MapField< java.lang.Long, org.tensorflow.proto.data.model.Model.ModelProto.Node> nodes_; private com.google.protobuf.MapField @@ -6230,43 +6415,43 @@ public org.tensorflow.proto.data.model.Model.ModelProto getDefaultInstanceForTyp static { java.lang.String[] descriptorData = { "\n%tensorflow/core/framework/model.proto\022" + - "\025tensorflow.data.model\"\207\010\n\nModelProto\022;\n" + - "\005nodes\030\001 \003(\0132,.tensorflow.data.model.Mod" + - "elProto.NodesEntry\022\016\n\006output\030\002 \001(\003\022\022\n\nid" + - "_counter\030\003 \001(\003\022Q\n\023optimization_params\030\005 " + - "\001(\01324.tensorflow.data.model.ModelProto.O" + - "ptimizationParams\022\021\n\tgap_times\030\006 \003(\004\032\277\004\n" + - "\004Node\022\n\n\002id\030\001 \001(\003\022\014\n\004name\030\002 \001(\t\022\020\n\010autot" + - "une\030\003 \001(\010\022\026\n\016buffered_bytes\030\004 \001(\003\022\031\n\021buf" + - "fered_elements\030\005 \001(\003\022\026\n\016bytes_consumed\030\006" + - " \001(\003\022\026\n\016bytes_produced\030\007 \001(\003\022\024\n\014num_elem" + - "ents\030\010 \001(\003\022\027\n\017processing_time\030\t \001(\003\022\026\n\016r" + - "ecord_metrics\030\n \001(\010\022D\n\nparameters\030\013 \003(\0132" + - "0.tensorflow.data.model.ModelProto.Node." + - "Parameter\022!\n\031input_processing_time_sum\030\014" + - " \001(\001\022#\n\033input_processing_time_count\030\r \001(" + - "\003\022\016\n\006inputs\030\016 \003(\003\0224\n\nnode_class\030\017 \001(\0162 ." 
+ - "tensorflow.data.model.NodeClass\022\r\n\005ratio" + - "\030\020 \001(\001\022\024\n\014memory_ratio\030\021 \001(\001\032h\n\tParamete" + - "r\022\014\n\004name\030\001 \001(\t\022\r\n\005value\030\002 \001(\001\022\023\n\013state_" + - "value\030\003 \001(\001\022\013\n\003min\030\004 \001(\001\022\013\n\003max\030\005 \001(\001\022\017\n" + - "\007tunable\030\006 \001(\010\032T\n\nNodesEntry\022\013\n\003key\030\001 \001(" + - "\003\0225\n\005value\030\002 \001(\0132&.tensorflow.data.model" + - ".ModelProto.Node:\0028\001\032\223\001\n\022OptimizationPar" + - "ams\022;\n\talgorithm\030\001 \001(\0162(.tensorflow.data" + - ".model.AutotuneAlgorithm\022\022\n\ncpu_budget\030\002" + - " \001(\003\022\022\n\nram_budget\030\003 \001(\003\022\030\n\020model_input_" + - "time\030\004 \001(\001J\004\010\004\020\005*\234\001\n\tNodeClass\022\013\n\007UNKNOW" + - "N\020\000\022\023\n\017INTERLEAVE_MANY\020\001\022\031\n\025ASYNC_INTERL" + - "EAVE_MANY\020\002\022\017\n\013KNOWN_RATIO\020\003\022\025\n\021ASYNC_KN" + - "OWN_RATIO\020\004\022\021\n\rUNKNOWN_RATIO\020\005\022\027\n\023ASYNC_" + - "UNKNOWN_RATIO\020\006*l\n\021AutotuneAlgorithm\022\013\n\007" + - "DEFAULT\020\000\022\016\n\nHILL_CLIMB\020\001\022\024\n\020GRADIENT_DE" + - "SCENT\020\002\022\023\n\017MAX_PARALLELISM\020\003\022\017\n\013STAGE_BA" + - "SED\020\004Br\n\037org.tensorflow.proto.data.model" + - "ZLgithub.com/tensorflow/tensorflow/tenso" + - "rflow/go/core/framework/model_go_proto\370\001" + - "\001b\006proto3" + "\025tensorflow.data.model\"\235\010\n\nModelProto\022\024\n" + + "\014dataset_name\030\007 \001(\t\022;\n\005nodes\030\001 \003(\0132,.ten" + + "sorflow.data.model.ModelProto.NodesEntry" + + "\022\016\n\006output\030\002 \001(\003\022\022\n\nid_counter\030\003 \001(\003\022Q\n\023" + + "optimization_params\030\005 \001(\01324.tensorflow.d" + + "ata.model.ModelProto.OptimizationParams\022" + + "\021\n\tgap_times\030\006 \003(\004\032\277\004\n\004Node\022\n\n\002id\030\001 
\001(\003\022" + + "\014\n\004name\030\002 \001(\t\022\020\n\010autotune\030\003 \001(\010\022\026\n\016buffe" + + "red_bytes\030\004 \001(\003\022\031\n\021buffered_elements\030\005 \001" + + "(\003\022\026\n\016bytes_consumed\030\006 \001(\003\022\026\n\016bytes_prod" + + "uced\030\007 \001(\003\022\024\n\014num_elements\030\010 \001(\003\022\027\n\017proc" + + "essing_time\030\t \001(\003\022\026\n\016record_metrics\030\n \001(" + + "\010\022D\n\nparameters\030\013 \003(\01320.tensorflow.data." + + "model.ModelProto.Node.Parameter\022!\n\031input" + + "_processing_time_sum\030\014 \001(\001\022#\n\033input_proc" + + "essing_time_count\030\r \001(\003\022\016\n\006inputs\030\016 \003(\003\022" + + "4\n\nnode_class\030\017 \001(\0162 .tensorflow.data.mo" + + "del.NodeClass\022\r\n\005ratio\030\020 \001(\001\022\024\n\014memory_r" + + "atio\030\021 \001(\001\032h\n\tParameter\022\014\n\004name\030\001 \001(\t\022\r\n" + + "\005value\030\002 \001(\001\022\023\n\013state_value\030\003 \001(\001\022\013\n\003min" + + "\030\004 \001(\001\022\013\n\003max\030\005 \001(\001\022\017\n\007tunable\030\006 \001(\010\032T\n\n" + + "NodesEntry\022\013\n\003key\030\001 \001(\003\0225\n\005value\030\002 \001(\0132&" + + ".tensorflow.data.model.ModelProto.Node:\002" + + "8\001\032\223\001\n\022OptimizationParams\022;\n\talgorithm\030\001" + + " \001(\0162(.tensorflow.data.model.AutotuneAlg" + + "orithm\022\022\n\ncpu_budget\030\002 \001(\003\022\022\n\nram_budget" + + "\030\003 \001(\003\022\030\n\020model_input_time\030\004 \001(\001J\004\010\004\020\005*\234" + + "\001\n\tNodeClass\022\013\n\007UNKNOWN\020\000\022\023\n\017INTERLEAVE_" + + "MANY\020\001\022\031\n\025ASYNC_INTERLEAVE_MANY\020\002\022\017\n\013KNO" + + "WN_RATIO\020\003\022\025\n\021ASYNC_KNOWN_RATIO\020\004\022\021\n\rUNK" + + "NOWN_RATIO\020\005\022\027\n\023ASYNC_UNKNOWN_RATIO\020\006*l\n" + + "\021AutotuneAlgorithm\022\013\n\007DEFAULT\020\000\022\016\n\nHILL_" + + 
"CLIMB\020\001\022\024\n\020GRADIENT_DESCENT\020\002\022\023\n\017MAX_PAR" + + "ALLELISM\020\003\022\017\n\013STAGE_BASED\020\004Br\n\037org.tenso" + + "rflow.proto.data.modelZLgithub.com/tenso" + + "rflow/tensorflow/tensorflow/go/core/fram" + + "ework/model_go_proto\370\001\001b\006proto3" }; descriptor = com.google.protobuf.Descriptors.FileDescriptor .internalBuildGeneratedFileFrom(descriptorData, @@ -6277,7 +6462,7 @@ public org.tensorflow.proto.data.model.Model.ModelProto getDefaultInstanceForTyp internal_static_tensorflow_data_model_ModelProto_fieldAccessorTable = new com.google.protobuf.GeneratedMessageV3.FieldAccessorTable( internal_static_tensorflow_data_model_ModelProto_descriptor, - new java.lang.String[] { "Nodes", "Output", "IdCounter", "OptimizationParams", "GapTimes", }); + new java.lang.String[] { "DatasetName", "Nodes", "Output", "IdCounter", "OptimizationParams", "GapTimes", }); internal_static_tensorflow_data_model_ModelProto_Node_descriptor = internal_static_tensorflow_data_model_ModelProto_descriptor.getNestedTypes().get(0); internal_static_tensorflow_data_model_ModelProto_Node_fieldAccessorTable = new diff --git a/tensorflow-core/tensorflow-core-native/src/gen/resources/org/tensorflow/base_api/api_def_ComputeDedupDataSize.pbtxt b/tensorflow-core/tensorflow-core-native/src/gen/resources/org/tensorflow/base_api/api_def_ComputeDedupDataSize.pbtxt new file mode 100644 index 00000000000..67b4a3f0761 --- /dev/null +++ b/tensorflow-core/tensorflow-core-native/src/gen/resources/org/tensorflow/base_api/api_def_ComputeDedupDataSize.pbtxt @@ -0,0 +1,22 @@ +op { + graph_op_name: "ComputeDedupDataSize" + visibility: HIDDEN + out_arg { + name: "num_elements" + description: <