63 changes: 32 additions & 31 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -15,6 +15,8 @@
 #include <cstdint>
 #include <cstring>
 #include <fstream>
+#include <iterator>
+#include <numeric>
 #include <optional>
 #include <sstream>
 #include <string>
@@ -268,20 +270,18 @@ std::vector<std::string> GetXlaDevices(
   return xla_devices;
 }
 
-std::vector<XLATensorPtr> GetXlaTensors(const std::vector<at::Tensor>& tensors,
-                                        bool want_all) {
+// Collects all valid `XLATensorPtr` out of `tensors`.
+//
+// Iterates through `tensors`, collecting every `XLATensorPtr` value,
+// ignoring those that return with a non-ok status.
+static std::vector<XLATensorPtr> CollectXlaTensors(
+    const std::vector<at::Tensor>& tensors) {
   std::vector<XLATensorPtr> xtensors;
   xtensors.reserve(tensors.size());
-  if (want_all) {
-    for (auto& tensor : tensors) {
-      xtensors.push_back(GetValueOrThrow(bridge::GetXlaTensor(tensor)));
-    }
-  } else {
-    for (auto& tensor : tensors) {
-      auto xtensor_status = bridge::GetXlaTensor(tensor);
-      if (xtensor_status.ok()) {
-        xtensors.push_back(std::move(xtensor_status).value());
-      }
+  for (auto& tensor : tensors) {
+    auto xla_tensor_status = bridge::GetXlaTensor(tensor);
+    if (xla_tensor_status.ok()) {
+      // Insert only those that can be successfully retrieved.
+      xtensors.push_back(std::move(xla_tensor_status).value());
     }
   }
   return xtensors;
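
The new helper makes the two lookup policies explicit: the strict path, GetValueOrThrow(bridge::GetXlaTensors(...)), fails the whole batch on the first tensor that is not XLA-backed, while CollectXlaTensors quietly skips such tensors. The self-contained sketch below models the two policies; std::optional stands in for the status-wrapped return type and FakeXlaTensorPtr for the real XLATensorPtr, so every name in it is an illustration rather than the actual torch_xla API.

// Illustration only: models the diff's two retrieval policies, with
// std::optional standing in for the status-wrapped return type.
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

using FakeXlaTensorPtr = std::shared_ptr<std::string>;  // stand-in alias

// Stand-in for bridge::GetXlaTensor: "succeeds" only for xla-device inputs.
std::optional<FakeXlaTensorPtr> GetXlaTensor(const std::string& t) {
  if (t.rfind("xla:", 0) == 0) return std::make_shared<std::string>(t);
  return std::nullopt;  // corresponds to a non-ok status
}

// Strict policy (formerly want_all=true): any failure aborts the batch,
// mirroring GetValueOrThrow(bridge::GetXlaTensors(tensors)).
std::vector<FakeXlaTensorPtr> GetAllOrThrow(const std::vector<std::string>& ts) {
  std::vector<FakeXlaTensorPtr> out;
  out.reserve(ts.size());
  for (const auto& t : ts) {
    auto x = GetXlaTensor(t);
    if (!x) throw std::runtime_error("not an XLA tensor: " + t);
    out.push_back(std::move(*x));
  }
  return out;
}

// Permissive policy (formerly want_all=false): mirrors CollectXlaTensors,
// keeping only the lookups that succeed.
std::vector<FakeXlaTensorPtr> CollectValid(const std::vector<std::string>& ts) {
  std::vector<FakeXlaTensorPtr> out;
  out.reserve(ts.size());
  for (const auto& t : ts) {
    if (auto x = GetXlaTensor(t)) out.push_back(std::move(*x));
  }
  return out;
}

int main() {
  const std::vector<std::string> ts = {"xla:0", "cpu:0", "xla:1"};
  std::cout << CollectValid(ts).size() << "\n";  // prints 2; cpu:0 is skipped
  try {
    GetAllOrThrow(ts);  // throws on cpu:0
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}

The call sites in the rest of the diff pick one policy or the other: collective ops need every tensor (strict), while debugging helpers such as GetTensorsHloGraph tolerate a mixed list (permissive).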
@@ -396,11 +396,11 @@ void AllReduceInPlace(const std::string& reduce_type,
                       bool pin_layout) {
   TORCH_LAZY_FN_COUNTER_TIMED_TRACING("xla::");
   std::vector<XLATensorPtr> xtensors =
-      GetXlaTensors(tensors, /*want_all=*/true);
+      GetValueOrThrow(bridge::GetXlaTensors(tensors));
   tensor_methods::all_reduce(xtensors, GetReduceType(reduce_type), scale,
                              replica_groups, pin_layout);
   std::vector<XLATensorPtr> new_xtensors =
-      GetXlaTensors(tensors, /*want_all=*/true);
+      GetValueOrThrow(bridge::GetXlaTensors(tensors));
   MaybeThrow(bridge::ReplaceXlaTensor(tensors, new_xtensors));
 }
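
GetValueOrThrow is the seam between the status-returning C++ bridge and the Python bindings: it unwraps a successful result and converts a failure into a C++ exception at the binding boundary. A toy sketch of that unwrap pattern follows; ToyStatusOr is invented for the illustration and is not the status type torch_xla actually uses.

// Toy StatusOr plus unwrap, illustrating the GetValueOrThrow pattern used
// at these call sites. Not the actual torch_xla/absl implementation.
#include <stdexcept>
#include <string>
#include <utility>
#include <variant>

template <typename T>
class ToyStatusOr {
 public:
  ToyStatusOr(T value) : v_(std::move(value)) {}            // ok state
  ToyStatusOr(std::string error) : v_(std::move(error)) {}  // error state
  bool ok() const { return std::holds_alternative<T>(v_); }
  T& value() { return std::get<T>(v_); }
  const std::string& error() const { return std::get<std::string>(v_); }

 private:
  std::variant<T, std::string> v_;
};

// Return the wrapped value, or raise; pybind11 would translate the thrown
// exception into a Python RuntimeError at the binding boundary.
template <typename T>
T GetValueOrThrow(ToyStatusOr<T> s) {
  if (!s.ok()) throw std::runtime_error(s.error());
  return std::move(s.value());
}

Under that shape, the two GetValueOrThrow(bridge::GetXlaTensors(tensors)) calls above read as: resolve every tensor, and surface the first failure instead of silently dropping it.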

@@ -506,7 +506,8 @@ ReduceScatterCoalesced(const std::string& reduce_type,
                        double scale, int64_t scatter_dim, int64_t shard_count,
                        const std::vector<std::vector<int64_t>>& replica_groups,
                        bool pin_layout) {
-  std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
+  std::vector<XLATensorPtr> xtensors =
+      GetValueOrThrow(bridge::GetXlaTensors(inputs));
   std::vector<XLATensorPtr> result;
   torch::lazy::Value new_token;
   std::tie(result, new_token) = tensor_methods::reduce_scatter_coalesced(
@@ -526,8 +527,9 @@ std::shared_ptr<torch::lazy::Value> ReduceScatterCoalescedOut(
     int64_t scatter_dim, int64_t shard_count,
     const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
   std::vector<XLATensorPtr> xtensors_out =
-      GetXlaTensors(outputs, /*want_all=*/true);
-  std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
+      GetValueOrThrow(bridge::GetXlaTensors(outputs));
+  std::vector<XLATensorPtr> xtensors =
+      GetValueOrThrow(bridge::GetXlaTensors(inputs));
   torch::lazy::Value new_token;
   new_token = tensor_methods::reduce_scatter_coalesced_out(
       xtensors_out, xtensors, *token, GetReduceType(reduce_type), scale,
@@ -568,7 +570,7 @@ AllGatherCoalesced(const std::vector<at::Tensor>& tensors,
                    const std::vector<std::vector<int64_t>>& replica_groups,
                    bool pin_layout) {
   std::vector<XLATensorPtr> xtensors =
-      GetXlaTensors(tensors, /*want_all=*/true);
+      GetValueOrThrow(bridge::GetXlaTensors(tensors));
   std::vector<XLATensorPtr> result;
   torch::lazy::Value new_token;
   std::tie(result, new_token) = tensor_methods::all_gather_coalesced(
@@ -586,8 +588,9 @@ std::shared_ptr<torch::lazy::Value> AllGatherCoalescedOut(
     int64_t shard_count,
     const std::vector<std::vector<int64_t>>& replica_groups, bool pin_layout) {
   std::vector<XLATensorPtr> xtensors_out =
-      GetXlaTensors(outputs, /*want_all=*/true);
-  std::vector<XLATensorPtr> xtensors = GetXlaTensors(inputs, /*want_all=*/true);
+      GetValueOrThrow(bridge::GetXlaTensors(outputs));
+  std::vector<XLATensorPtr> xtensors =
+      GetValueOrThrow(bridge::GetXlaTensors(inputs));
   torch::lazy::Value new_token;
   new_token = tensor_methods::all_gather_coalesced_out(
       xtensors_out, xtensors, *token, dim, shard_count, replica_groups,
@@ -624,8 +627,7 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> CollectivePermute(
 }
 
 void OptimizationBarrier_(std::vector<at::Tensor>& tensors) {
-  std::vector<XLATensorPtr> xtensors =
-      GetXlaTensors(tensors, /*want_all=*/false);
+  auto xtensors = CollectXlaTensors(tensors);
   tensor_methods::optimization_barrier_(xtensors);
 }

@@ -654,8 +656,7 @@ std::pair<at::Tensor, std::shared_ptr<torch::lazy::Value>> Recv(
 void SyncTensors(const std::vector<at::Tensor>& tensors,
                  const std::vector<std::string>& devices, bool wait,
                  bool sync_xla_data, bool warm_up_cache_only = false) {
-  std::vector<XLATensorPtr> xtensors =
-      GetXlaTensors(tensors, /*want_all=*/false);
+  std::vector<XLATensorPtr> xtensors = CollectXlaTensors(tensors);
   XLAGraphExecutor::Get()->SyncTensorsGraph(&xtensors, devices, wait,
                                             sync_xla_data, warm_up_cache_only);
 }
@@ -704,8 +705,7 @@ uint64_t GetRngSeed(const std::string& device_str) {
 
 std::string GetTensorsHloGraph(const std::vector<at::Tensor>& tensors,
                                EmitMode mode) {
-  std::vector<XLATensorPtr> xtensors =
-      GetXlaTensors(tensors, /*want_all=*/false);
+  std::vector<XLATensorPtr> xtensors = CollectXlaTensors(tensors);
   return XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode);
 }

@@ -884,7 +884,8 @@ py::object GetRevisions() {
 std::vector<at::Tensor> XlaUserComputation(
     const std::string& opname, const std::vector<at::Tensor>& inputs,
     runtime::ComputationClient::ComputationPtr computation) {
-  std::vector<XLATensorPtr> xinputs = GetXlaTensors(inputs, /*want_all=*/true);
+  std::vector<XLATensorPtr> xinputs =
+      GetValueOrThrow(bridge::GetXlaTensors(inputs));
   std::vector<XLATensorPtr> xresults =
       tensor_methods::user_computation(opname, xinputs, std::move(computation));
   std::vector<at::Tensor> results;
@@ -1141,7 +1142,7 @@ class PyLoweringContext {
   void Build(std::vector<at::Tensor> tensors) {
     // Get the backing XLA tensors from the output torch tensor handles
     std::vector<XLATensorPtr> xtensors =
-        GetXlaTensors(tensors, /*want_all=*/true);
+        GetValueOrThrow(bridge::GetXlaTensors(tensors));
 
     // Get the lazy IR value from the output XLA tensors
     std::vector<torch::lazy::Value> ir_values;
@@ -1168,7 +1169,7 @@
       std::vector<at::Tensor> additional_inputs_list = {}) {
     // Get the backing XLA tensors from the output torch tensor handles
     std::vector<XLATensorPtr> xtensors =
-        GetXlaTensors(tensors, /*want_all=*/true);
+        GetValueOrThrow(bridge::GetXlaTensors(tensors));
 
     // Get the lazy IR value from the output XLA tensors
    std::vector<torch::lazy::Value> ir_values;
@@ -2285,7 +2286,7 @@ void InitXlaModuleBindings(py::module m) {
             xtensors =
                 XLAGraphExecutor::Get()->GetLiveTensors(&backend_device);
           } else {
-            xtensors = GetXlaTensors(tensors, /*want_all=*/false);
+            xtensors = CollectXlaTensors(tensors);
           }
           return py::bytes(
               XLAGraphExecutor::Get()->DumpHloComputation(xtensors, mode));