Skip to content

Conversation

ysiraichi
Copy link
Collaborator

This PR deduplicates GetXlaTensors().
Previously, there were 2 functions in the repository with the same name, similar functionality, different signature:

init_python_bindings.cpp

std::vector<XLATensorPtr> GetXlaTensors(const std::vector<at::Tensor>& tensors,
                                        bool want_all);

aten_xla_bridge.h

absl::StatusOr<std::vector<absl_nonnull XLATensorPtr>> GetXlaTensors(
    const at::ITensorListRef& tensors);
  • if want_all=true, it behaved similarly to GetXlaTensors()
  • otherwise, it collected all XLA tensors from the given list of tensors
  • the function in init_python_bindings.cpp lived under an anonymous namespace, not used anywhere else

Therefore, this PR introduces the following key changes:

  • Replaces all init_python_bindings.cpp GetXlaTensors(..., /* want_all= */ true) call sites with GetValueOrThrow(bridge::GetXlaTensors(...))
  • Rename init_python_bindings.cpp GetXlaTensors() to CollectXlaTensors(), avoiding functions with different semantics to have the same name
  • Re-define CollectXlaTensors() to have only the want_all=false behavior, removing want_all flag from the function parameters

Copy link
Collaborator

@zhanyong-wan zhanyong-wan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

@ysiraichi ysiraichi merged commit cb64f4c into master Jul 30, 2025
23 of 24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants