Add MLIR side effects to tf.XlaCallModule
#60643
Merged
This change updates the tf.XlaCallModule op to support jax2tf native serialization. The tf.XlaCallModule op wraps a StableHLO module, which may call TF host callback functions through stablehlo.custom_call. So that the side effects of these callbacks are modeled correctly, the following changes were made:

- The Pure trait in the automatically generated tf.XlaCallModule op definition has been replaced with the MemoryEffects trait, so MLIR no longer treats the op as side-effect free.
- The isStateful flag has been set in the op declaration of XlaCallModule to indicate that it has stateful behavior.
- The TensorFlow side effect analysis has been updated to recursively analyze the TF host callback functions invoked by tf.XlaCallModule.

Together, these changes align tf.XlaCallModule with the jax2tf native serialization process: side effects inside host callbacks become visible to TensorFlow, so the op can be handled correctly when it is integrated into TensorFlow graphs.

PiperOrigin-RevId: 533635753
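Setting isStateful matters because TensorFlow does not apply some optimizations, such as common subexpression elimination (CSE), to stateful ops. The toy sketch below is not TensorFlow's implementation: the Op class and cse function are hypothetical. It shows how a CSE pass that respects a stateful flag merges two identical stateless ops but keeps two identical stateful calls, since running a host callback twice may be observable.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Op:
    """Hypothetical, simplified op: a name, input value ids, and a stateful flag."""
    name: str
    inputs: Tuple[str, ...]
    is_stateful: bool = False


def cse(ops: List[Tuple[str, Op]]) -> List[Tuple[str, Op]]:
    """Toy common subexpression elimination over (result_id, op) pairs.

    Stateless ops with the same name and inputs are merged; stateful ops are
    always kept, because executing them twice may have observable effects.
    """
    seen = {}    # (name, inputs) -> result id of the first stateless occurrence
    rename = {}  # eliminated result id -> result id that replaces it
    result = []
    for result_id, op in ops:
        # Point inputs at the surviving results of any eliminated ops.
        op = Op(op.name, tuple(rename.get(i, i) for i in op.inputs), op.is_stateful)
        key = (op.name, op.inputs)
        if not op.is_stateful:
            if key in seen:
                rename[result_id] = seen[key]
                continue
            seen[key] = result_id
        result.append((result_id, op))
    return result


program = [
    ("a", Op("Add", ("x", "y"))),
    ("b", Op("Add", ("x", "y"))),                          # stateless duplicate: merged into "a"
    ("c", Op("XlaCallModule", ("x",), is_stateful=True)),
    ("d", Op("XlaCallModule", ("x",), is_stateful=True)),  # stateful duplicate: kept
]
print([rid for rid, _ in cse(program)])  # ['a', 'c', 'd']
```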
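The third change makes the side effect analysis look through tf.XlaCallModule into the host callback functions it invokes. The sketch below illustrates that idea under a simplified model. It is not TensorFlow's analysis, and the Function class, its fields, and has_side_effects are hypothetical names. A function counts as side-effecting if any of its own ops has effects or any function it transitively calls does. A visited set prevents infinite recursion on recursive call chains.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Set


@dataclass
class Function:
    """Hypothetical function model: its side-effecting ops plus the functions it calls."""
    name: str
    effectful_ops: List[str] = field(default_factory=list)  # ops with side effects
    callees: List[str] = field(default_factory=list)        # called functions, by name


def has_side_effects(name: str, functions: Dict[str, Function],
                     visited: Optional[Set[str]] = None) -> bool:
    """Returns True if `name` or any function it transitively calls has side effects."""
    if visited is None:
        visited = set()
    if name in visited:
        # Already explored (or on the current call chain); its effects, if any,
        # are reported from that first visit.
        return False
    visited.add(name)
    fn = functions[name]
    if fn.effectful_ops:
        return True
    return any(has_side_effects(callee, functions, visited) for callee in fn.callees)


functions = {
    # "main" stands in for the StableHLO module called by tf.XlaCallModule.
    "main": Function("main", callees=["callback_a"]),
    "callback_a": Function("callback_a", callees=["callback_b"]),
    "callback_b": Function("callback_b", effectful_ops=["tf.PrintV2"]),
    "pure_fn": Function("pure_fn", callees=["pure_fn"]),  # self-recursive, no effects
}
print(has_side_effects("main", functions))     # True
print(has_side_effects("pure_fn", functions))  # False
```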