Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/cpp/cpp_test_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ std::vector<xla::ComputationClient::DataPtr> Execute(
lowering_ctx.AddResult(root);
}

xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build());
xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
xla::Shape shape = MakeShapeWithDeviceLayout(
program_shape.result(), static_cast<XlaDeviceType>(device.type()));
Expand Down
5 changes: 2 additions & 3 deletions torch_xla/csrc/ir_dump_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,10 @@ std::string DumpUtil::ToHlo(c10::ArrayRef<torch::lazy::Value> values,
const torch::lazy::BackendDevice& device) {
LoweringContext lowering_ctx("IrToHlo", device);
for (auto& ir_value : values) {
xla::XlaOp root = lowering_ctx.GetOutputOp(
lowering_ctx.AddResult(
torch::lazy::Output(ir_value.node.get(), ir_value.index));
lowering_ctx.AddResult(root);
}
xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build());
xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
return ConsumeValue(xla::util::GetComputationHloText(computation));
}

Expand Down
60 changes: 48 additions & 12 deletions torch_xla/csrc/lowering_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/sys_util.h"
#include "torch/csrc/lazy/core/ir_metadata.h"
#include "torch_xla/csrc/computation.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {
Expand Down Expand Up @@ -75,15 +77,14 @@ class HloMetadataSetter {

LoweringContext::LoweringContext(const std::string& name,
torch::lazy::BackendDevice device)
: builder_(name), device_(std::move(device)) {}
: torch::lazy::LoweringContext(name, device), builder_(name) {}

LoweringContext::LoweringContext(
const std::string& name, torch::lazy::BackendDevice device,
absl::Span<const torch::lazy::Node* const> post_order,
c10::ArrayRef<const torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status)
: builder_(name),
device_(std::move(device)),
emit_status_(std::move(emit_status)) {
: torch::lazy::LoweringContext(name, device, {}, emit_status),
builder_(name) {
for (auto node : post_order) {
LowerNode(node);
}
Expand Down Expand Up @@ -114,11 +115,6 @@ const std::vector<size_t>& LoweringContext::GetParameterSequence() const {
return parameter_sequence_;
}

// Appends `op` to the tuple of computation results and returns the index
// at which it was stored (its position within the output tuple).
size_t LoweringContext::AddResult(xla::XlaOp op) {
root_tuple_.push_back(std::move(op));
return root_tuple_.size() - 1;
}

// Returns the result op previously registered at `index` in the root
// tuple; throws std::out_of_range if no result exists at that index.
xla::XlaOp LoweringContext::GetResult(size_t index) const {
return root_tuple_.at(index);
}
Expand All @@ -127,15 +123,15 @@ void LoweringContext::SetResult(size_t index, xla::XlaOp op) {
root_tuple_.at(index) = std::move(op);
}

// Finalizes the computation captured by the embedded XlaBuilder. When
// results have been registered through AddResult(), they are wrapped into
// a single tuple which becomes the computation root; otherwise the
// builder's default root is used.
xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla() {
  if (root_tuple_.empty()) {
    return builder()->Build();
  }
  xla::XlaOp tuple_root = xla::Tuple(builder(), root_tuple_);
  return builder()->Build(tuple_root);
}

// Finalizes the computation using `root` as its return value. It is an
// error to call this overload after AddResult() has registered results,
// since those would be silently dropped — hence the empty-tuple check.
xla::StatusOr<xla::XlaComputation> LoweringContext::BuildXla(xla::XlaOp root) {
  XLA_CHECK(root_tuple_.empty());
  return builder()->Build(root);
}
Expand Down Expand Up @@ -195,4 +191,44 @@ void LoweringContext::ReportBuilderError(const torch::lazy::Node* node,
throw std::runtime_error(ss.str());
}

// Registers an input/output buffer alias with the embedded builder,
// pairing the output element at `output_index` with the element at
// `param_index` of parameter `param_number`. Only single-element index
// paths are supported here (enforced by the checks below).
// NOTE(review): `must_alias` is not forwarded to builder_.SetUpAlias() —
// presumably the XlaBuilder overload in use does not accept it; confirm
// against the pinned XLA version.
void LoweringContext::SetUpAlias(const std::vector<int64_t>& output_index,
int64_t param_number,
const std::vector<int64_t>& param_index,
bool must_alias) {
XLA_CHECK_EQ(output_index.size(), 1);
XLA_CHECK_EQ(param_index.size(), 1);
builder_.SetUpAlias({output_index[0]}, param_number, {param_index[0]});
}

// Returns true when the XLA shape of `parameter_data` exactly matches the
// shape of the result op registered at `result_idx` in the root tuple.
bool LoweringContext::CheckResultShape(
const torch::lazy::BackendDataPtr& parameter_data, size_t result_idx) {
xla::XlaOp root = GetResult(result_idx);
// Unwrap the backend-agnostic data handle to reach the XLA shape.
const xla::Shape& root_shape = XlaHelpers::ShapeOfXlaOp(root);
return UnwrapXlaData(parameter_data)->shape() == root_shape;
}

// Adds the XLA op that was lowered for `output` to the result tuple and
// returns its index within the tuple. Backend-agnostic counterpart of the
// xla::XlaOp overload; `output` must already have a lowering.
size_t LoweringContext::AddResult(const torch::lazy::Output& output) {
root_tuple_.push_back(GetOutputOp(output));
return root_tuple_.size() - 1;
}

// Appends `op` to the computation's result tuple and returns the position
// it occupies in the output tuple.
size_t LoweringContext::AddResult(xla::XlaOp op) {
  const size_t result_index = root_tuple_.size();
  root_tuple_.push_back(op);
  return result_index;
}

// Registers a parameter for `output` at position `index`. This part of
// the torch::lazy::LoweringContext interface is not supported by the XLA
// backend and always reports an error.
void LoweringContext::AddParameter(const torch::lazy::Output& output,
size_t index,
const torch::lazy::Shape& shape,
const std::string& name) {
XLA_ERROR() << "not implemented";
}

// Overrides torch::lazy::LoweringContext::Build(): builds the XLA
// computation via BuildXla() and wraps it in the backend-agnostic
// torch_xla::Computation type.
torch::lazy::ComputationPtr LoweringContext::Build() {
xla::XlaComputation xla_computation = ConsumeValue(BuildXla());
return std::make_shared<torch_xla::Computation>(builder_.name(),
std::move(xla_computation));
}

} // namespace torch_xla
36 changes: 22 additions & 14 deletions torch_xla/csrc/lowering_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,20 @@
#include "tensorflow/compiler/xla/xla_client/computation_client.h"
#include "tensorflow/core/platform/macros.h"
#include "torch/csrc/lazy/backend/backend_data.h"
#include "torch/csrc/lazy/backend/lowering_context.h"
#include "torch/csrc/lazy/core/ir_util.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir.h"
#include "torch_xla/csrc/ir_util.h"

namespace torch_xla {

class LoweringContext {
class LoweringContext : public torch::lazy::LoweringContext {
public:
explicit LoweringContext(const std::string& name,
torch::lazy::BackendDevice device);
LoweringContext(const std::string& name, torch::lazy::BackendDevice device,
absl::Span<const torch::lazy::Node* const> post_order,
c10::ArrayRef<const torch::lazy::Node*> post_order,
torch::lazy::Util::EmissionMap emit_status);

xla::XlaBuilder* builder() { return &builder_; }
Expand All @@ -43,10 +44,6 @@ class LoweringContext {

const std::vector<size_t>& GetParameterSequence() const;

// Adds the output of a given operation to the result tuple. Returns the index
// of the output within the tuple.
size_t AddResult(xla::XlaOp op);

xla::XlaOp GetResult(size_t index) const;

void SetResult(size_t index, xla::XlaOp op);
Expand All @@ -63,19 +60,34 @@ class LoweringContext {

// Build the XLA computation capturing all the operations created with the
// embedded XLA builder (returned by the builder() API).
xla::StatusOr<xla::XlaComputation> Build();
xla::StatusOr<xla::XlaComputation> BuildXla();

// Build the XLA computation capturing all the operations created with the
// embedded XLA builder (returned by the builder() API).
// Uses root as return value for the computation. It is an error to use this
// API after having called the AddResult() API.
xla::StatusOr<xla::XlaComputation> Build(xla::XlaOp root);
xla::StatusOr<xla::XlaComputation> BuildXla(xla::XlaOp root);

// Lowers a single IR node. All the inputs to the node must have a lowering
// before calling this API. Returns the generated XLA operations.
XlaOpVector LowerNode(const torch::lazy::Node* node);

size_t GetEmittedNodeCount() const { return emit_status_.size(); }
void SetUpAlias(const std::vector<int64_t>& output_index,
int64_t param_number, const std::vector<int64_t>& param_index,
bool must_alias = false) override;

bool CheckResultShape(const torch::lazy::BackendDataPtr& parameter_data,
size_t result_idx) override;

size_t AddResult(const torch::lazy::Output& output) override;

size_t AddResult(xla::XlaOp op);

void AddParameter(const torch::lazy::Output& output, size_t index,
const torch::lazy::Shape& shape,
const std::string& name) override;

torch::lazy::ComputationPtr Build() override;

private:
struct Parameter {
Expand All @@ -88,14 +100,10 @@ class LoweringContext {
const char* error_msg);

xla::XlaBuilder builder_;
torch::lazy::BackendDevice device_;
std::vector<torch::lazy::BackendDataPtr> parameters_;
std::unordered_map<torch::lazy::BackendData::Handle, Parameter>
parameters_map_;
std::vector<size_t> parameter_sequence_;
std::vector<xla::XlaOp> root_tuple_;
OutputMap<xla::XlaOp> emitted_outputs_;
torch::lazy::Util::EmissionMap emit_status_;
};
}; // namespace torch_xla

} // namespace torch_xla
2 changes: 1 addition & 1 deletion torch_xla/csrc/op_by_op_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ xla::XlaComputation BuildNodeComputation(
for (auto& xla_op : loctx.LowerNode(node)) {
loctx.AddResult(xla_op);
}
return ConsumeValue(loctx.Build());
return ConsumeValue(loctx.BuildXla());
}

torch::lazy::hash_t GetNodesKeySeed(const std::string& device,
Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1656,7 +1656,7 @@ XLATensor::CompilationResult XLATensor::Compile(
BuildInputOutputAliases(tensors, coll.indices, &lowering_ctx);
}

xla::XlaComputation computation = ConsumeValue(lowering_ctx.Build());
xla::XlaComputation computation = ConsumeValue(lowering_ctx.BuildXla());
xla::ProgramShape program_shape = ConsumeValue(computation.GetProgramShape());
xla::Shape shape = MakeShapeWithDeviceLayout(
program_shape.result(), static_cast<XlaDeviceType>(coll.device.type()));
Expand Down