25 changes: 16 additions & 9 deletions torch_xla/csrc/ir_builder.h
@@ -1,18 +1,21 @@
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "torch/csrc/lazy/core/ir.h"
#include "torch/csrc/lazy/core/ir_builder.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ops/as_strided.h"
#include "torch_xla/csrc/ops/cast.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/ops/diagonal.h"
#include "torch_xla/csrc/ops/expand.h"
#include "torch_xla/csrc/ops/generic.h"
#include "torch_xla/csrc/ops/ops.h"
#include "torch_xla/csrc/tensor_util.h"

namespace torch_xla {

struct XLAIrBuilder : IrBuilder {
struct XLAIrBuilder : torch::lazy::IrBuilder {
torch::lazy::NodePtr MakeDeviceData(
const std::shared_ptr<BackendData>& data) const override {
const std::shared_ptr<torch::lazy::BackendData>& data) const override {
return torch::lazy::MakeNode<DeviceData>(data);
}

@@ -30,27 +33,31 @@ struct XLAIrBuilder : IrBuilder {
torch::lazy::NodePtr MakeView(
const torch::lazy::Value& input0,
const std::vector<int64_t>& output_size) const override {
return torch::lazy::MakeNode<ViewOp>(input0, output_size);
// TODO(JackCaoG): use functionization pass instead
Collaborator:

nit: why are we changing this to the functionalization pass?

Collaborator (Author):

So after the functionalization pass we no longer see the view operator; every view op becomes a new view_copy op. The complicated part of views is that they are supposed to share the same memory space. This is not possible for pytorch/xla right now because XLA does not expose its memory space to the user. What we do instead is create two nodes, and whenever an in-place operation is performed on one of them, we replay that operation onto the other node (after reshape).

After the functionalization pass we no longer need to maintain this relationship in the pytorch/xla layer. We can treat a view op like any other op. PyTorch core will handle applying the same in-place operation to the nodes that are views of each other.
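To make this concrete, here is a minimal standalone sketch (using a hypothetical FakeLazyTensor type, not the real torch_xla classes) of the "replay" bookkeeping described above: because the backend cannot give two view tensors a shared memory space, an in-place update applied to one tensor has to be re-applied to its aliases by hand. Functionalization removes the need for this by rewriting view into view_copy and letting PyTorch core propagate the mutation.

#include <functional>
#include <iostream>
#include <vector>

struct FakeLazyTensor {
  std::vector<float> data;
  // Tensors that are logically views of this one and must observe any
  // in-place update applied to it.
  std::vector<FakeLazyTensor*> aliases;

  void ApplyInPlace(const std::function<void(std::vector<float>&)>& op) {
    op(data);
    // Replay the same update on every alias, since there is no shared
    // memory space that would keep the aliases consistent automatically.
    for (FakeLazyTensor* alias : aliases) op(alias->data);
  }
};

int main() {
  FakeLazyTensor a{{1.f, 2.f, 3.f, 4.f}, {}};
  FakeLazyTensor b{a.data, {}};  // b plays the role of a view of a
  a.aliases.push_back(&b);
  b.aliases.push_back(&a);

  // An in-place add on a must also be visible through b.
  a.ApplyInPlace([](std::vector<float>& v) {
    for (float& x : v) x += 1.f;
  });
  std::cout << b.data[0] << "\n";  // prints 2, as if b aliased a's storage
  return 0;
}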

return nullptr;
}
torch::lazy::NodePtr MakeCast(const torch::lazy::Value& input0,
const at::ScalarType& dtype,
const c10::optional<at::ScalarType>& stype =
c10::nullopt) const override {
return torch::lazy::MakeNode<Cast>(input0, dtype, stype);
}
torch::lazy::NodePtr MakeTensorList(const OpList& inputs) const override {
torch::lazy::NodePtr MakeTensorList(
const torch::lazy::OpList& inputs) const override {
// TODO(JackCaoG): implement tensorList IR. This is used by codegen.
XLA_ERROR() << "Need to implement";
return nullptr;
}
// Generic needs cleanup
torch::lazy::NodePtr MakeGeneric(
const OpKind& op, const OpList& operands, const Shape& shape,
const size_t& num_outputs = 1,
const hash_t& hash_seed =
const torch::lazy::OpKind& op, const torch::lazy::OpList& operands,
const torch::lazy::Shape& shape, const size_t& num_outputs = 1,
const torch::lazy::hash_t& hash_seed =
static_cast<uint32_t>(0x5a2d296e9)) const override {
return torch::lazy::MakeNode<Generic>(op, operands, shape, num_outputs,
hash_seed);
// TODO(JackCaoG): ltc generic op does not take lowering function
// return torch::lazy::MakeNode<Generic>(
// op, operands, MakeXlaShapeFromLazyShape(shape, *GetDefaultDevice()),
// num_outputs, hash_seed);
}

// We should use functionization pass for view ops when migrating to the LTC.
5 changes: 3 additions & 2 deletions torch_xla/csrc/xla_backend_impl.cpp
@@ -6,6 +6,7 @@
#include "torch_xla/csrc/aten_xla_bridge.h"
#include "torch_xla/csrc/computation.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/ir_builder.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/device_data.h"
#include "torch_xla/csrc/tensor.h"
@@ -31,8 +32,8 @@ class XlaBackendImpl : public torch::lazy::BackendImplInterface {
}

const torch::lazy::IrBuilder* GetIrBuilder() const override {
XLA_ERROR() << "Not implemented yet";
return 0;
static const torch::lazy::IrBuilder* builder = new XLAIrBuilder();
return builder;
}

torch::lazy::BackendDataPtr MakeComputationDataFromTensor(