34 changes: 1 addition & 33 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -815,39 +815,7 @@ XLAGraphExecutor::PostOrderData XLAGraphExecutor::RunPostOrder(
    SyncTensorCollection* coll) {
  tensorflow::profiler::TraceMe activity(
      "RunPostOrder", tensorflow::profiler::TraceMeLevel::kInfo);
-  std::vector<const torch::lazy::Node*> roots;
-  roots.reserve(ir_values.size());
-  for (auto ir_value : ir_values) {
-    roots.push_back(ir_value.node.get());
-  }
-  PostOrderData po_data;
-  po_data.post_order =
-      torch::lazy::Util::ComputePostOrder(roots, &po_data.emission_map);
-  std::unordered_map<xla::ComputationClient::Data::OpaqueHandle, size_t>
-      data_handles;
-
-  for (auto node : po_data.post_order) {
-    const auto backend_data =
-        torch::lazy::getBackend()->GetComputationDataFromNode(node);
-    if (backend_data != nullptr) {
-      /* Acceptable race condition: HasValue may return false. This is OK
-       * since the conditional barrier is a performance optimization. */
-      if (!backend_data->HasValue()) {
-        TensorCollectionBarrier(coll);
-      }
-      xla::ComputationClient::Data::OpaqueHandle handle =
-          backend_data->GetHandle();
-      auto it = data_handles.find(handle);
-      if (it != data_handles.end()) {
-        po_data.parameter_sequence.push_back(it->second);
-      } else {
-        po_data.parameter_sequence.push_back(po_data.parameters_data.size());
-        data_handles[handle] = po_data.parameters_data.size();
-        po_data.parameters_data.push_back(backend_data);
-      }
-    }
-  }
-  return po_data;
+  return torch::lazy::LazyGraphExecutor::RunPostOrder(ir_values, coll);
}

XLAGraphExecutor::ComputationCache::TypePtr
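The net effect of this hunk: RunPostOrder keeps only the profiler tracing and delegates the post-order walk (including the handle-based parameter deduplication visible in the removed lines) to the shared upstream implementation. Assembled from the context and added lines above, the resulting function is:

    XLAGraphExecutor::PostOrderData XLAGraphExecutor::RunPostOrder(
        const std::vector<torch::lazy::Value>& ir_values,
        SyncTensorCollection* coll) {
      tensorflow::profiler::TraceMe activity(
          "RunPostOrder", tensorflow::profiler::TraceMeLevel::kInfo);
      // Delegate to upstream LazyGraphExecutor, which provides the equivalent
      // post-order computation and parameter-deduplication logic.
      return torch::lazy::LazyGraphExecutor::RunPostOrder(ir_values, coll);
    }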
8 changes: 7 additions & 1 deletion torch_xla/csrc/xla_graph_executor.h
@@ -56,6 +56,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
// that is affected by the view tensor.
void ApplyEagerSync(std::vector<XLATensorPtr>& tensors);

+// We don't use the upstream GetDeviceDataIrValue because ours takes an
+// xla::PrimitiveType.
torch::lazy::Value GetDeviceDataIrValue(
const at::Scalar& value, xla::PrimitiveType type,
const torch::lazy::BackendDevice& device);
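A hypothetical call site for this XLA-specific overload, to illustrate how the signature diverges from upstream (which takes an ATen scalar type rather than an xla::PrimitiveType). The singleton accessor and device lookup below are assumptions for the sketch, not part of this PR:

    // Sketch only: materialize the scalar 1.0 as an F32 device-data IR value.
    const torch::lazy::BackendDevice device = bridge::GetCurrentDevice();
    torch::lazy::Value one = XLAGraphExecutor::Get()->GetDeviceDataIrValue(
        at::Scalar(1.0), xla::PrimitiveType::F32, device);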
@@ -243,6 +245,8 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
absl::Span<const size_t> indices,
const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec);

+// We don't use the upstream ExtractIRAndPrepareTensorData because we need to
+// instantiate the xla::Shape.
void ExtractIRAndPrepareXlaData_(
std::vector<XLATensorPtr>* tensors, const SyncTensorsConfig& config,
const absl::Span<const size_t> indices,
@@ -271,8 +275,9 @@
std::string device, ComputationCache::TypePtr cached_computation,
const std::vector<torch::lazy::BackendDataPtr>& tensor_data_vec);

+// Overridden to enable profiler tracing.
PostOrderData RunPostOrder(const std::vector<torch::lazy::Value>& ir_values,
-                           SyncTensorCollection* coll);
+                           SyncTensorCollection* coll) final;

// We don't use the upstream LookupCachedCompile since
// our CachedComputation is different from upstream.
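The `final` added to RunPostOrder above also guarantees that no subclass of XLAGraphExecutor can override it again; a minimal standalone sketch of that C++ pattern (names are illustrative, not from this PR):

    struct Base {
      virtual int Run() { return 0; }
      virtual ~Base() = default;
    };
    struct Derived : Base {
      // Augment (e.g., with tracing), then delegate to the base implementation.
      int Run() final { return Base::Run() + 1; }
    };
    // struct Further : Derived { int Run() override; };  // ill-formed: Run is final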
@@ -290,6 +295,7 @@ class XLAGraphExecutor : public torch::lazy::LazyGraphExecutor {
const std::vector<XLATensorPtr>& tensors,
absl::Span<const size_t> indices, LoweringContext* lowering_ctx);

+// We don't use the upstream Compile because we need BuildInputOutputAliases.
CompilationResult Compile(const std::vector<XLATensorPtr>& tensors,
absl::Span<const std::string> devices,
const SyncTensorCollection& coll,
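For context on the Compile comment above: BuildInputOutputAliases sets up XLA input/output buffer aliasing so an input buffer can be donated and reused for an output. A minimal sketch of the underlying XlaBuilder mechanism, assuming the standard XLA client API (this is not the PR's code):

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Build a computation whose tuple output element {0} aliases parameter 0,
    // allowing the runtime to reuse (donate) the input buffer for the result.
    xla::XlaComputation BuildAliasedAdd() {
      xla::XlaBuilder builder("aliasing_sketch");
      xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 2});
      xla::XlaOp p0 = xla::Parameter(&builder, 0, shape, "p0");
      xla::XlaOp sum = xla::Add(p0, p0);
      xla::XlaOp root = xla::Tuple(&builder, {sum});
      builder.SetUpAlias(/*output_index=*/{0}, /*param_number=*/0,
                         /*param_index=*/{});
      return builder.Build(root).value();
    }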