diff --git a/torch_xla/csrc/init_python_bindings.cpp b/torch_xla/csrc/init_python_bindings.cpp index bd4152aee81..da2701bb21d 100644 --- a/torch_xla/csrc/init_python_bindings.cpp +++ b/torch_xla/csrc/init_python_bindings.cpp @@ -15,6 +15,8 @@ #include #include #include +#include +#include #include #include #include @@ -268,20 +270,18 @@ std::vector GetXlaDevices( return xla_devices; } -std::vector GetXlaTensors(const std::vector& tensors, - bool want_all) { +// Collects all valid `XLATensorPtr` out of `tensors`. +// +// Iterates through `tensors`, collecting every `XLATensorPtr` value, +// ignoring those that return with a non-ok status. +static std::vector CollectXlaTensors( + const std::vector& tensors) { std::vector xtensors; - xtensors.reserve(tensors.size()); - if (want_all) { - for (auto& tensor : tensors) { - xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor))); - } - } else { - for (auto& tensor : tensors) { - auto xtensor_status = bridge::GetXlaTensor(tensor); - if (xtensor_status.ok()) { - xtensors.push_back(std::move(xtensor_status).value()); - } + for (auto& tensor : tensors) { + auto xla_tensor_status = bridge::GetXlaTensor(tensor); + if (xla_tensor_status.ok()) { + // Insert only those that can be successfully retrieved. + xtensors.push_back(std::move(xla_tensor_status).value()); } } return xtensors; @@ -396,11 +396,11 @@ void AllReduceInPlace(const std::string& reduce_type, bool pin_layout) { TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::"); std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); tensor_methods::all_reduce(xtensors, GetReduceType(reduce_type), scale, replica_groups, pin_layout); std::vector new_xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); MaybeThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors)); } @@ -506,7 +506,8 @@ ReduceScatterCoalesced(const std::string& reduce_type, double scale, int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { - std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + std::vector xtensors = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); std::vector result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced( @@ -526,8 +527,9 @@ std::shared_ptr ReduceScatterCoalescedOut( int64_t scatter_dim, int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { std::vector xtensors_out = - GetXlaTensors(outputs, /*want_all=*/true); - std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(outputs)); + std::vector xtensors = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); torch::lazy::Value new_token; new_token = tensor_methods::reduce_scatter_coalesced_out( xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale, @@ -568,7 +570,7 @@ AllGatherCoalesced(const std::vector& tensors, const std::vector>& replica_groups, bool pin_layout) { std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); std::vector result; torch::lazy::Value new_token; std::tie(result, new_token) = tensor_methods::all_gather_coalesced( @@ -586,8 +588,9 @@ std::shared_ptr AllGatherCoalescedOut( int64_t shard_count, const std::vector>& replica_groups, bool pin_layout) { std::vector xtensors_out = - GetXlaTensors(outputs, /*want_all=*/true); - std::vector xtensors = GetXlaTensors(inputs, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(outputs)); + std::vector xtensors = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); torch::lazy::Value new_token; new_token = tensor_methods::all_gather_coalesced_out( xtensors_out, xtensors, *token, dim, shard_count, replica_groups, @@ -624,8 +627,7 @@ std::pair> CollectivePermute( } void OptimizationBarrier_(std::vector& tensors) { - std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/false); + auto xtensors = CollectXlaTensors(tensors); tensor_methods::optimization_barrier_(xtensors); } @@ -654,8 +656,7 @@ std::pair> Recv( void SyncTensors(const std::vector& tensors, const std::vector& devices, bool wait, bool sync_xla_data, bool warm_up_cache_only = false) { - std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/false); + std::vector xtensors = CollectXlaTensors(tensors); XLAGraphExecutor::Get()->SyncTensorsGraph(&xtensors, devices, wait, sync_xla_data, warm_up_cache_only); } @@ -704,8 +705,7 @@ uint64_t GetRngSeed(const std::string& device_str) { std::string GetTensorsHloGraph(const std::vector& tensors, EmitMode mode) { - std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/false); + std::vector xtensors = CollectXlaTensors(tensors); return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode); } @@ -884,7 +884,8 @@ py::object GetRevisions() { std::vector XlaUserComputation( const std::string& opname, const std::vector& inputs, runtime::ComputationClient::ComputationPtr computation) { - std::vector xinputs = GetXlaTensors(inputs, /*want_all=*/true); + std::vector xinputs = + GetValueOrThrow(bridge::GetXlaTensors(inputs)); std::vector xresults = tensor_methods::user_computation(opname, xinputs, std::move(computation)); std::vector results; @@ -1141,7 +1142,7 @@ class PyLoweringContext { void Build(std::vector tensors) { // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); // Get the lazy IR value from the output XLA tensors std::vector ir_values; @@ -1168,7 +1169,7 @@ class PyLoweringContext { std::vector additional_inputs_list = {}) { // Get the backing XLA tensors from the output torch tensor handles std::vector xtensors = - GetXlaTensors(tensors, /*want_all=*/true); + GetValueOrThrow(bridge::GetXlaTensors(tensors)); // Get the lazy IR value from the output XLA tensors std::vector ir_values; @@ -2285,7 +2286,7 @@ void InitXlaModuleBindings(py::module m) { xtensors = XLAGraphExecutor::Get()->GetLiveTensors(&backend_device); } else { - xtensors = GetXlaTensors(tensors, /*want_all=*/false); + xtensors = CollectXlaTensors(tensors); } return py::bytes( XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode));