diff --git a/torch_xla/csrc/aten_xla_type.cpp b/torch_xla/csrc/aten_xla_type.cpp index e415f2caf3f4..eeec34d518c2 100644 --- a/torch_xla/csrc/aten_xla_type.cpp +++ b/torch_xla/csrc/aten_xla_type.cpp @@ -1821,7 +1821,7 @@ at::Tensor XLANativeFunctions::log(const at::Tensor& self) { at::Tensor XLANativeFunctions::log10(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::log_base( - bridge::GetXlaTensor(self), ir::OpKind(at::aten::log10), 10.0)); + bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log10), 10.0)); } at::Tensor XLANativeFunctions::log1p(const at::Tensor& self) { @@ -1833,7 +1833,7 @@ at::Tensor XLANativeFunctions::log1p(const at::Tensor& self) { at::Tensor XLANativeFunctions::log2(const at::Tensor& self) { XLA_FN_COUNTER("xla::"); return bridge::AtenFromXlaTensor(XLATensor::log_base( - bridge::GetXlaTensor(self), ir::OpKind(at::aten::log2), 2.0)); + bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log2), 2.0)); } at::Tensor XLANativeFunctions::log_sigmoid_backward( diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 98474897131a..589ab2734a58 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -127,16 +127,8 @@ torch::lazy::hash_t Value::hash() const { return torch::lazy::HashCombine(node->hash(), index); } -OpKind OpKind::Get(const std::string& name) { - return OpKind(c10::Symbol::fromQualString(name)); -} - -torch::lazy::hash_t OpKind::hash() const { - return torch::lazy::StringHash(op.toQualString()); -} - -Node::Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs, - torch::lazy::hash_t hash_seed) +Node::Node(torch::lazy::OpKind op, OpList operands, xla::Shape shape, + size_t num_outputs, torch::lazy::hash_t hash_seed) : op_(std::move(op)), num_outputs_(num_outputs), shape_(std::move(shape)), @@ -150,7 +142,7 @@ Node::Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs, } } -Node::Node(OpKind op, OpList operands, +Node::Node(torch::lazy::OpKind op, OpList operands, const std::function& shape_fn, size_t num_outputs, torch::lazy::hash_t hash_seed) : Node(std::move(op), operands, xla::Shape(), num_outputs, hash_seed) { @@ -159,7 +151,7 @@ Node::Node(OpKind op, OpList operands, shape_ = GetOpShape(shape_fn); } -Node::Node(OpKind op, xla::Shape shape, size_t num_outputs, +Node::Node(torch::lazy::OpKind op, xla::Shape shape, size_t num_outputs, torch::lazy::hash_t hash_seed) : op_(std::move(op)), num_outputs_(num_outputs), @@ -247,7 +239,8 @@ XlaOpVector Node::Lower(LoweringContext* loctx) const { XLA_ERROR() << "Lowering not implemented for node: " << *this; } -torch::lazy::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape, +torch::lazy::hash_t Node::GetOpHash(torch::lazy::OpKind op, + const xla::Shape& shape, torch::lazy::hash_t hash_seed) { torch::lazy::hash_t h = torch::lazy::HashCombine(op.hash(), torch::lazy::Hash(shape.ToString())); diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index e0710f6333dc..027c21caca17 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -18,6 +18,7 @@ #include "tensorflow/compiler/xla/xla_client/types.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "torch/csrc/lazy/core/hash.h" +#include "torch/csrc/lazy/core/ir.h" #include "torch_xla/csrc/python_util.h" namespace torch_xla { @@ -134,34 +135,6 @@ struct Value { size_t index = 0; }; -// The Kind of operation a Node can be associated to. -struct OpKind { - OpKind() = default; - explicit OpKind(c10::Symbol op) : op(std::move(op)) {} - - bool operator==(const OpKind& rhs) const { return op == rhs.op; } - bool operator!=(const OpKind& rhs) const { return !operator==(rhs); } - bool operator<(const OpKind& rhs) const { - return c10::unique_t(op) < c10::unique_t(rhs.op); - } - - torch::lazy::hash_t hash() const; - - std::string ToString() const { return op.toQualString(); } - - // Retrieves an existing operation object, or creates a new one. Operations - // that are specific to the XLA side, should live within the 'xla::' - // namespace. - static OpKind Get(const std::string& name); - - c10::Symbol op; -}; - -inline std::ostream& operator<<(std::ostream& stream, const OpKind& op) { - stream << op.ToString(); - return stream; -} - using OpList = absl::Span; // A node in the graph. Nodes for operations which requires extra data to be @@ -175,22 +148,23 @@ class Node { // Creates a new node with the given op name. The op is a unique identifier // for the operation. The num_outputs tells how many outputs a given operation // generates. - Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs = 1, + Node(torch::lazy::OpKind op, OpList operands, xla::Shape shape, + size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); // Same as the constructor above, but the shape is generated by a function, // only if needed (shape cache miss). - Node(OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs = 1, + Node(torch::lazy::OpKind op, OpList operands, + const std::function& shape_fn, size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); // Contructor used to create leaf nodes. - Node(OpKind op, xla::Shape shape, size_t num_outputs, + Node(torch::lazy::OpKind op, xla::Shape shape, size_t num_outputs, torch::lazy::hash_t hash_seed); virtual ~Node(); - const OpKind& op() const { return op_; } + const torch::lazy::OpKind& op() const { return op_; } size_t num_outputs() const { return num_outputs_; } @@ -247,13 +221,14 @@ class Node { xla::Shape GetOpShape(const std::function& shape_fn) const; - static torch::lazy::hash_t GetOpHash(OpKind op, const xla::Shape& shape, + static torch::lazy::hash_t GetOpHash(torch::lazy::OpKind op, + const xla::Shape& shape, torch::lazy::hash_t hash_seed); static std::vector GetFrameInfo(); // The ID of the operation captured by this node. - OpKind op_; + torch::lazy::OpKind op_; size_t num_outputs_ = 1; xla::Shape shape_; // A node holds a real reference to its operands. @@ -295,7 +270,7 @@ NodePtr MakeNode(Args&&... args) { } template -T* NodeCast(const Node* node, OpKind op) { +T* NodeCast(const Node* node, torch::lazy::OpKind op) { if (op != node->op()) { return nullptr; } diff --git a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp index e04000f7e7e8..8d4a8c34bd65 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input, AdaptiveAvgPool2d::AdaptiveAvgPool2d(const Value& input, std::vector output_size) - : Node(ir::OpKind(at::aten::adaptive_avg_pool2d), {input}, + : Node(torch::lazy::OpKind(at::aten::adaptive_avg_pool2d), {input}, [&]() { return NodeOutputShape(input, output_size); }, /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} diff --git a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp index a9fd7a08028b..a389bb52dec4 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input, AdaptiveAvgPool3d::AdaptiveAvgPool3d(const Value& input, std::vector output_size) - : Node(ir::OpKind(at::aten::adaptive_avg_pool3d), {input}, + : Node(torch::lazy::OpKind(at::aten::adaptive_avg_pool3d), {input}, [&]() { return NodeOutputShape(input, output_size); }, /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} diff --git a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp index 88aa7c80a7af..846a3379eb98 100644 --- a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp @@ -25,7 +25,7 @@ xla::Shape NodeOutputShape(const Value& input, AdaptiveMaxPool2d::AdaptiveMaxPool2d(const Value& input, std::vector output_size) - : Node(ir::OpKind(at::aten::adaptive_max_pool2d), {input}, + : Node(torch::lazy::OpKind(at::aten::adaptive_max_pool2d), {input}, [&]() { return NodeOutputShape(input, output_size); }, /*num_outputs=*/2, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} diff --git a/torch_xla/csrc/ops/all.cpp b/torch_xla/csrc/ops/all.cpp index c08da0bfc4dc..4108c680690f 100644 --- a/torch_xla/csrc/ops/all.cpp +++ b/torch_xla/csrc/ops/all.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, All::All(const Value& input, std::vector dimensions, bool keep_reduced_dimensions) - : Node(ir::OpKind(at::aten::all), {input}, + : Node(torch::lazy::OpKind(at::aten::all), {input}, NodeOutputShape(input, dimensions, keep_reduced_dimensions), /*num_outputs=*/1, torch::lazy::MHash(dimensions, keep_reduced_dimensions)), diff --git a/torch_xla/csrc/ops/amax.cpp b/torch_xla/csrc/ops/amax.cpp index 5ed3a73670e7..9fe98082f413 100644 --- a/torch_xla/csrc/ops/amax.cpp +++ b/torch_xla/csrc/ops/amax.cpp @@ -21,7 +21,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, } // namespace Amax::Amax(const Value& input, std::vector dimensions, bool keepdim) - : Node(ir::OpKind(at::aten::amax), {input}, + : Node(torch::lazy::OpKind(at::aten::amax), {input}, [&]() { return NodeOutputShape(input, dimensions, keepdim); }, /*num_outputs=*/1, torch::lazy::MHash(dimensions, keepdim)), dimensions_(std::move(dimensions)), diff --git a/torch_xla/csrc/ops/amin.cpp b/torch_xla/csrc/ops/amin.cpp index 19bdf73968aa..7f4e347578ce 100644 --- a/torch_xla/csrc/ops/amin.cpp +++ b/torch_xla/csrc/ops/amin.cpp @@ -21,7 +21,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, } // namespace Amin::Amin(const Value& input, std::vector dimensions, bool keepdim) - : Node(ir::OpKind(at::aten::amin), {input}, + : Node(torch::lazy::OpKind(at::aten::amin), {input}, [&]() { return NodeOutputShape(input, dimensions, keepdim); }, /*num_outputs=*/1, torch::lazy::MHash(dimensions, keepdim)), dimensions_(std::move(dimensions)), diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp index bab99cbd9b21..dcb45c6537b5 100644 --- a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp +++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp @@ -35,7 +35,8 @@ std::vector GetOperandList(absl::Span operands, AmpForachNonFiniteCheckAndUnscale::AmpForachNonFiniteCheckAndUnscale( const OpList& inputs, const Value& found_inf, const Value& inv_scale) - : Node(ir::OpKind(at::aten::_amp_foreach_non_finite_check_and_unscale_), + : Node(torch::lazy::OpKind( + at::aten::_amp_foreach_non_finite_check_and_unscale_), GetOperandList(inputs, found_inf, inv_scale), NodeOutputShape(inputs, found_inf), /*num_outputs=*/inputs.size() + 1) {} diff --git a/torch_xla/csrc/ops/amp_update_scale.cpp b/torch_xla/csrc/ops/amp_update_scale.cpp index a4f3fceb7101..b7e1d4cefae9 100644 --- a/torch_xla/csrc/ops/amp_update_scale.cpp +++ b/torch_xla/csrc/ops/amp_update_scale.cpp @@ -24,7 +24,7 @@ AmpUpdateScale::AmpUpdateScale(const Value& current_scale, const Value& found_inf, double scale_growth_factor, double scale_backoff_factor, int growth_interval) - : Node(ir::OpKind(at::aten::_amp_update_scale_), + : Node(torch::lazy::OpKind(at::aten::_amp_update_scale_), {current_scale, growth_tracker, found_inf}, NodeOutputShape(growth_tracker, current_scale), /*num_outputs=*/2), diff --git a/torch_xla/csrc/ops/any.cpp b/torch_xla/csrc/ops/any.cpp index 545554b44f2a..ea5d12424757 100644 --- a/torch_xla/csrc/ops/any.cpp +++ b/torch_xla/csrc/ops/any.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, Any::Any(const Value& input, std::vector dimensions, bool keep_reduced_dimensions) - : Node(ir::OpKind(at::aten::any), {input}, + : Node(torch::lazy::OpKind(at::aten::any), {input}, NodeOutputShape(input, dimensions, keep_reduced_dimensions), /*num_outputs=*/1, torch::lazy::MHash(dimensions, keep_reduced_dimensions)), diff --git a/torch_xla/csrc/ops/arg_max.cpp b/torch_xla/csrc/ops/arg_max.cpp index c5a263b8b421..9da14c01c381 100644 --- a/torch_xla/csrc/ops/arg_max.cpp +++ b/torch_xla/csrc/ops/arg_max.cpp @@ -20,7 +20,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t dim, bool keepdim) { } // namespace ArgMax::ArgMax(const Value& input, int64_t dim, bool keepdim) - : Node(ir::OpKind(at::aten::argmax), {input}, + : Node(torch::lazy::OpKind(at::aten::argmax), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, /*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)), dim_(dim), diff --git a/torch_xla/csrc/ops/arg_min.cpp b/torch_xla/csrc/ops/arg_min.cpp index 9ba4d3a80678..410e53dbe7e0 100644 --- a/torch_xla/csrc/ops/arg_min.cpp +++ b/torch_xla/csrc/ops/arg_min.cpp @@ -20,7 +20,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t dim, bool keepdim) { } // namespace ArgMin::ArgMin(const Value& input, int64_t dim, bool keepdim) - : Node(ir::OpKind(at::aten::argmin), {input}, + : Node(torch::lazy::OpKind(at::aten::argmin), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, /*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)), dim_(dim), diff --git a/torch_xla/csrc/ops/arithmetic_ir_ops.cpp b/torch_xla/csrc/ops/arithmetic_ir_ops.cpp index 6bca2bc68773..9d75f7cab520 100644 --- a/torch_xla/csrc/ops/arithmetic_ir_ops.cpp +++ b/torch_xla/csrc/ops/arithmetic_ir_ops.cpp @@ -16,7 +16,7 @@ NodePtr operator+(const Value& node1, const Value& node2) { return node.ReturnOp(XlaHelpers::PromotedAdd(op0, op1), loctx); }; return ops::GenericOp( - OpKind(at::aten::add), {node1, node2}, + torch::lazy::OpKind(at::aten::add), {node1, node2}, XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()), std::move(lower_fn)); } @@ -28,7 +28,7 @@ NodePtr operator-(const Value& node1, const Value& node2) { return node.ReturnOp(XlaHelpers::PromotedSub(op0, op1), loctx); }; return ops::GenericOp( - OpKind(at::aten::sub), {node1, node2}, + torch::lazy::OpKind(at::aten::sub), {node1, node2}, XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()), std::move(lower_fn)); } @@ -40,7 +40,7 @@ NodePtr operator*(const Value& node1, const Value& node2) { return node.ReturnOp(XlaHelpers::PromotedMul(op0, op1), loctx); }; return ops::GenericOp( - OpKind(at::aten::mul), {node1, node2}, + torch::lazy::OpKind(at::aten::mul), {node1, node2}, XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()), std::move(lower_fn)); } @@ -52,7 +52,7 @@ NodePtr operator/(const Value& node1, const Value& node2) { return node.ReturnOp(XlaHelpers::PromotedDiv(op0, op1), loctx); }; return ops::GenericOp( - OpKind(at::aten::div), {node1, node2}, + torch::lazy::OpKind(at::aten::div), {node1, node2}, XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()), std::move(lower_fn)); } diff --git a/torch_xla/csrc/ops/as_strided.cpp b/torch_xla/csrc/ops/as_strided.cpp index faf9a56fbd6a..7a8931995d88 100644 --- a/torch_xla/csrc/ops/as_strided.cpp +++ b/torch_xla/csrc/ops/as_strided.cpp @@ -45,7 +45,7 @@ xla::XlaOp LowerAsStrided(xla::XlaOp input, absl::Span size, AsStrided::AsStrided(const Value& input, std::vector size, std::vector stride, int64_t storage_offset) - : Node(ir::OpKind(at::aten::as_strided), {input}, + : Node(torch::lazy::OpKind(at::aten::as_strided), {input}, [&]() { return xla::ShapeUtil::MakeShape(input.shape().element_type(), size); diff --git a/torch_xla/csrc/ops/avg_pool_nd.cpp b/torch_xla/csrc/ops/avg_pool_nd.cpp index dc7e38e9aece..21827f6b9c4b 100644 --- a/torch_xla/csrc/ops/avg_pool_nd.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd.cpp @@ -47,7 +47,7 @@ AvgPoolNd::AvgPoolNd(const Value& input, int64_t spatial_dim_count, std::vector kernel_size, std::vector stride, std::vector padding, bool ceil_mode, bool count_include_pad) - : Node(ir::OpKind(AvgPoolNdSymbol(spatial_dim_count)), {input}, + : Node(torch::lazy::OpKind(AvgPoolNdSymbol(spatial_dim_count)), {input}, [&]() { return NodeOutputShape(input, spatial_dim_count, kernel_size, stride, padding, ceil_mode, diff --git a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp index e77ce124c5b1..eb70e5ad8b14 100644 --- a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp @@ -47,7 +47,8 @@ AvgPoolNdBackward::AvgPoolNdBackward( const Value& grad_output, const Value& input, int64_t spatial_dim_count, std::vector kernel_size, std::vector stride, std::vector padding, bool ceil_mode, bool count_include_pad) - : Node(OpKind(AvgNdBackwardSymbol(spatial_dim_count)), {grad_output, input}, + : Node(torch::lazy::OpKind(AvgNdBackwardSymbol(spatial_dim_count)), + {grad_output, input}, [&]() { return NodeOutputShape(grad_output, input, spatial_dim_count, kernel_size, stride, padding, ceil_mode, diff --git a/torch_xla/csrc/ops/bernoulli.cpp b/torch_xla/csrc/ops/bernoulli.cpp index a64c73678afe..bb8b9959210d 100644 --- a/torch_xla/csrc/ops/bernoulli.cpp +++ b/torch_xla/csrc/ops/bernoulli.cpp @@ -10,7 +10,7 @@ namespace ops { Bernoulli::Bernoulli(const Value& probability, const Value& seed, xla::Shape shape) - : Node(ir::OpKind(at::aten::bernoulli), {probability, seed}, + : Node(torch::lazy::OpKind(at::aten::bernoulli), {probability, seed}, std::move(shape)) {} NodePtr Bernoulli::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/binary_cross_entropy.cpp b/torch_xla/csrc/ops/binary_cross_entropy.cpp index e1e2e7b583c5..713cf8049618 100644 --- a/torch_xla/csrc/ops/binary_cross_entropy.cpp +++ b/torch_xla/csrc/ops/binary_cross_entropy.cpp @@ -35,7 +35,7 @@ xla::Shape NodeOutputShape(const Value& logits, const Value& labels, BinaryCrossEntropy::BinaryCrossEntropy(const Value& logits, const Value& labels, const absl::optional& weight, ReductionMode reduction) - : Node(ir::OpKind(at::aten::binary_cross_entropy), + : Node(torch::lazy::OpKind(at::aten::binary_cross_entropy), xla::util::GetValuesVector({logits, labels}, {&weight}), [&]() { return NodeOutputShape(logits, labels, weight, reduction); }, /*num_outputs=*/1, diff --git a/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp b/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp index 628ebf59723a..7aa01e4c85f2 100644 --- a/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp +++ b/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp @@ -37,7 +37,7 @@ xla::Shape NodeOutputShape(const Value& grad_output, const Value& logits, BinaryCrossEntropyBackward::BinaryCrossEntropyBackward( const Value& grad_output, const Value& logits, const Value& labels, const absl::optional& weight, ReductionMode reduction) - : Node(ir::OpKind(at::aten::binary_cross_entropy_backward), + : Node(torch::lazy::OpKind(at::aten::binary_cross_entropy_backward), xla::util::GetValuesVector({grad_output, logits, labels}, {&weight}), [&]() { diff --git a/torch_xla/csrc/ops/bitwise_ir_ops.cpp b/torch_xla/csrc/ops/bitwise_ir_ops.cpp index f410d54ef349..6b27aba8ab41 100644 --- a/torch_xla/csrc/ops/bitwise_ir_ops.cpp +++ b/torch_xla/csrc/ops/bitwise_ir_ops.cpp @@ -23,7 +23,7 @@ Value BitwiseAnd(const Value& node1, const Value& node2) { [](xla::XlaOp lhs, xla::XlaOp rhs) { return lhs & rhs; }); }; return GenericOp( - OpKind(at::aten::__and__), {node1, node2}, + torch::lazy::OpKind(at::aten::__and__), {node1, node2}, [&]() { return InferOutputShape({node1.shape(), node2.shape()}, shape_fn); }, @@ -44,7 +44,7 @@ Value BitwiseOr(const Value& node1, const Value& node2) { [](xla::XlaOp lhs, xla::XlaOp rhs) { return lhs | rhs; }); }; return GenericOp( - OpKind(at::aten::__or__), {node1, node2}, + torch::lazy::OpKind(at::aten::__or__), {node1, node2}, [&]() { return InferOutputShape({node1.shape(), node2.shape()}, shape_fn); }, @@ -65,7 +65,7 @@ Value BitwiseXor(const Value& node1, const Value& node2) { [](xla::XlaOp lhs, xla::XlaOp rhs) { return lhs ^ rhs; }); }; return GenericOp( - OpKind(at::aten::__xor__), {node1, node2}, + torch::lazy::OpKind(at::aten::__xor__), {node1, node2}, [&]() { return InferOutputShape({node1.shape(), node2.shape()}, shape_fn); }, diff --git a/torch_xla/csrc/ops/cat.cpp b/torch_xla/csrc/ops/cat.cpp index 72e96b5bceac..eb09edd9d2f0 100644 --- a/torch_xla/csrc/ops/cat.cpp +++ b/torch_xla/csrc/ops/cat.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(absl::Span values, int64_t dim) { } // namespace Cat::Cat(absl::Span values, int64_t dim) - : Node(ir::OpKind(at::aten::cat), values, + : Node(torch::lazy::OpKind(at::aten::cat), values, [&]() { return NodeOutputShape(values, dim); }, /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/cholesky.cpp b/torch_xla/csrc/ops/cholesky.cpp index 08bfe04d1174..f6d4ba95c62b 100644 --- a/torch_xla/csrc/ops/cholesky.cpp +++ b/torch_xla/csrc/ops/cholesky.cpp @@ -9,7 +9,7 @@ namespace ir { namespace ops { Cholesky::Cholesky(const Value& input, bool lower) - : Node(ir::OpKind(at::aten::cholesky), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::cholesky), {input}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(lower)), lower_(lower) {} diff --git a/torch_xla/csrc/ops/constant.cpp b/torch_xla/csrc/ops/constant.cpp index 8442cf5c9450..f87de1f0db27 100644 --- a/torch_xla/csrc/ops/constant.cpp +++ b/torch_xla/csrc/ops/constant.cpp @@ -10,8 +10,8 @@ namespace ir { namespace ops { Constant::Constant(xla::Literal value) - : Node(OpKind(at::prim::Constant), value.shape(), /*num_outputs=*/1, - absl::Hash{}(value)), + : Node(torch::lazy::OpKind(at::prim::Constant), value.shape(), + /*num_outputs=*/1, absl::Hash{}(value)), value_(std::move(value)) {} std::string Constant::ToString() const { diff --git a/torch_xla/csrc/ops/constant_pad_nd.cpp b/torch_xla/csrc/ops/constant_pad_nd.cpp index ef60c960e664..2169a1de332b 100644 --- a/torch_xla/csrc/ops/constant_pad_nd.cpp +++ b/torch_xla/csrc/ops/constant_pad_nd.cpp @@ -35,7 +35,7 @@ xla::Shape NodeOutputShape(const Value& input, const at::Scalar& value, ConstantPadNd::ConstantPadNd(const Value& input, std::vector pad, const at::Scalar& value) - : Node(ir::OpKind(at::aten::constant_pad_nd), {input}, + : Node(torch::lazy::OpKind(at::aten::constant_pad_nd), {input}, [&]() { return NodeOutputShape(input, value, pad); }, /*num_outputs=*/1, torch::lazy::MHash(pad, ScalarHash(value))), pad_(std::move(pad)), diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp index d1144a0a33af..4e44c0733eba 100644 --- a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp @@ -40,7 +40,7 @@ ConvolutionBackwardOverrideable::ConvolutionBackwardOverrideable( std::vector stride, std::vector padding, std::vector dilation, bool transposed, std::vector output_padding, int64_t groups) - : Node(ir::OpKind(at::aten::convolution_backward_overrideable), + : Node(torch::lazy::OpKind(at::aten::convolution_backward_overrideable), {grad_output, input, weight}, [&]() { return NodeOutputShape(grad_output, input, weight, stride, padding, diff --git a/torch_xla/csrc/ops/convolution_overrideable.cpp b/torch_xla/csrc/ops/convolution_overrideable.cpp index af9d6ad01aa8..412c999b2a9d 100644 --- a/torch_xla/csrc/ops/convolution_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_overrideable.cpp @@ -36,7 +36,7 @@ ConvolutionOverrideable::ConvolutionOverrideable( std::vector stride, std::vector padding, std::vector dilation, bool transposed, std::vector output_padding, int64_t groups) - : Node(ir::OpKind(at::aten::convolution_overrideable), + : Node(torch::lazy::OpKind(at::aten::convolution_overrideable), {input, weight, bias}, [&]() { return NodeOutputShape(input, weight, stride, padding, dilation, @@ -56,7 +56,8 @@ ConvolutionOverrideable::ConvolutionOverrideable( const Value& input, const Value& weight, std::vector stride, std::vector padding, std::vector dilation, bool transposed, std::vector output_padding, int64_t groups) - : Node(ir::OpKind(at::aten::convolution_overrideable), {input, weight}, + : Node(torch::lazy::OpKind(at::aten::convolution_overrideable), + {input, weight}, [&]() { return NodeOutputShape(input, weight, stride, padding, dilation, transposed, output_padding, groups); diff --git a/torch_xla/csrc/ops/cumprod.cpp b/torch_xla/csrc/ops/cumprod.cpp index 775b5b70b637..d86c97660597 100644 --- a/torch_xla/csrc/ops/cumprod.cpp +++ b/torch_xla/csrc/ops/cumprod.cpp @@ -39,7 +39,7 @@ xla::Shape NodeOutputShape(const Value& input, CumProd::CumProd(const Value& input, int64_t dim, c10::optional dtype) - : Node(ir::OpKind(at::aten::cumprod), {input}, + : Node(torch::lazy::OpKind(at::aten::cumprod), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, torch::lazy::MHash(dim, torch::lazy::OptionalOr(dtype, -1))), diff --git a/torch_xla/csrc/ops/cumsum.cpp b/torch_xla/csrc/ops/cumsum.cpp index 246be8270ddd..cfa12f76a8f2 100644 --- a/torch_xla/csrc/ops/cumsum.cpp +++ b/torch_xla/csrc/ops/cumsum.cpp @@ -38,7 +38,7 @@ xla::Shape NodeOutputShape(const Value& input, CumSum::CumSum(const Value& input, int64_t dim, c10::optional dtype) - : Node(ir::OpKind(at::aten::cumsum), {input}, + : Node(torch::lazy::OpKind(at::aten::cumsum), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, torch::lazy::MHash(dim, torch::lazy::OptionalOr(dtype, -1))), diff --git a/torch_xla/csrc/ops/diagonal.cpp b/torch_xla/csrc/ops/diagonal.cpp index 8da703cf3625..9174fdb5e254 100644 --- a/torch_xla/csrc/ops/diagonal.cpp +++ b/torch_xla/csrc/ops/diagonal.cpp @@ -13,7 +13,7 @@ namespace ops { Diagonal::Diagonal(const Value& input, int64_t offset, int64_t dim1, int64_t dim2) - : Node(ir::OpKind(at::aten::diagonal), {input}, + : Node(torch::lazy::OpKind(at::aten::diagonal), {input}, [&]() { return MakeDiagonalShape(input.shape(), offset, dim1, dim2); }, diff --git a/torch_xla/csrc/ops/discrete_uniform.cpp b/torch_xla/csrc/ops/discrete_uniform.cpp index 1707433ff36a..43f37d8a9b9b 100644 --- a/torch_xla/csrc/ops/discrete_uniform.cpp +++ b/torch_xla/csrc/ops/discrete_uniform.cpp @@ -12,7 +12,7 @@ namespace ops { DiscreteUniform::DiscreteUniform(const Value& from, const Value& to, const Value& seed, const xla::Shape& rng_shape) - : Node(ir::OpKind(at::aten::random), {from, to, seed}, rng_shape, + : Node(torch::lazy::OpKind(at::aten::random), {from, to, seed}, rng_shape, /*num_outputs=*/1, torch::lazy::Hash(rng_shape)) {} NodePtr DiscreteUniform::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/expand.cpp b/torch_xla/csrc/ops/expand.cpp index cb38bcd1f8c3..8bbacb4efaea 100644 --- a/torch_xla/csrc/ops/expand.cpp +++ b/torch_xla/csrc/ops/expand.cpp @@ -22,7 +22,7 @@ xla::Shape NodeOutputShape(const Value& input, } // namespace Expand::Expand(const Value& input, std::vector size) - : Node(ir::OpKind(at::aten::expand), {input}, + : Node(torch::lazy::OpKind(at::aten::expand), {input}, [&]() { return NodeOutputShape(input, size); }, /*num_outputs=*/1, torch::lazy::MHash(size)), size_(std::move(size)) {} diff --git a/torch_xla/csrc/ops/exponential.cpp b/torch_xla/csrc/ops/exponential.cpp index 8238b5a7425d..505458f7056a 100644 --- a/torch_xla/csrc/ops/exponential.cpp +++ b/torch_xla/csrc/ops/exponential.cpp @@ -10,7 +10,7 @@ namespace ops { Exponential::Exponential(const Value& lambda, const Value& seed, xla::Shape shape) - : Node(ir::OpKind(at::aten::exponential), {lambda, seed}, + : Node(torch::lazy::OpKind(at::aten::exponential), {lambda, seed}, std::move(shape)) {} NodePtr Exponential::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/flip.cpp b/torch_xla/csrc/ops/flip.cpp index df67cc4a7b9f..ffe97fcac0c9 100644 --- a/torch_xla/csrc/ops/flip.cpp +++ b/torch_xla/csrc/ops/flip.cpp @@ -8,7 +8,7 @@ namespace ir { namespace ops { Flip::Flip(const Value& input, std::vector dims) - : Node(ir::OpKind(at::aten::flip), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::flip), {input}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(dims)), dims_(std::move(dims)) {} diff --git a/torch_xla/csrc/ops/gather.cpp b/torch_xla/csrc/ops/gather.cpp index 42abb2d8162a..0e6c8304a1dd 100644 --- a/torch_xla/csrc/ops/gather.cpp +++ b/torch_xla/csrc/ops/gather.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input, const Value& index, } // namespace Gather::Gather(const Value& input, int64_t dim, const Value& index) - : Node(ir::OpKind(at::aten::gather), {input, index}, + : Node(torch::lazy::OpKind(at::aten::gather), {input, index}, [&]() { return NodeOutputShape(input, index, dim); }, /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/generic.cpp b/torch_xla/csrc/ops/generic.cpp index 0a774fa15a6c..7e6b61eb61da 100644 --- a/torch_xla/csrc/ops/generic.cpp +++ b/torch_xla/csrc/ops/generic.cpp @@ -6,29 +6,29 @@ namespace torch_xla { namespace ir { namespace ops { -Generic::Generic(OpKind op, absl::Span operands, xla::Shape shape, - LowerFn lower_fn, size_t num_outputs, +Generic::Generic(torch::lazy::OpKind op, absl::Span operands, + xla::Shape shape, LowerFn lower_fn, size_t num_outputs, torch::lazy::hash_t hash_seed) : Node(std::move(op), operands, std::move(shape), num_outputs, hash_seed), lower_fn_(std::move(lower_fn)), hash_seed_(hash_seed) {} -Generic::Generic(OpKind op, absl::Span operands, +Generic::Generic(torch::lazy::OpKind op, absl::Span operands, const std::function& shape_fn, LowerFn lower_fn, size_t num_outputs, torch::lazy::hash_t hash_seed) : Node(std::move(op), operands, shape_fn, num_outputs, hash_seed), lower_fn_(std::move(lower_fn)), hash_seed_(hash_seed) {} -Generic::Generic(OpKind op, xla::Shape shape, LowerFn lower_fn, +Generic::Generic(torch::lazy::OpKind op, xla::Shape shape, LowerFn lower_fn, size_t num_outputs, torch::lazy::hash_t hash_seed) : Node(std::move(op), std::move(shape), num_outputs, hash_seed), lower_fn_(std::move(lower_fn)), hash_seed_(hash_seed) {} NodePtr Generic::Clone(OpList operands) const { - return MakeNode(op(), operands, shape(), lower_fn_, num_outputs(), - hash_seed_); + return ir::MakeNode(op(), operands, shape(), lower_fn_, + num_outputs(), hash_seed_); } XlaOpVector Generic::Lower(LoweringContext* loctx) const { diff --git a/torch_xla/csrc/ops/generic.h b/torch_xla/csrc/ops/generic.h index f426426e5b19..e28bfac298e9 100644 --- a/torch_xla/csrc/ops/generic.h +++ b/torch_xla/csrc/ops/generic.h @@ -15,17 +15,17 @@ class Generic : public Node { public: using LowerFn = std::function; - Generic(OpKind op, absl::Span operands, xla::Shape shape, - LowerFn lower_fn, size_t num_outputs = 1, + Generic(torch::lazy::OpKind op, absl::Span operands, + xla::Shape shape, LowerFn lower_fn, size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); - Generic(OpKind op, absl::Span operands, + Generic(torch::lazy::OpKind op, absl::Span operands, const std::function& shape_fn, LowerFn lower_fn, size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); - Generic(OpKind op, xla::Shape shape, LowerFn lower_fn, size_t num_outputs, - torch::lazy::hash_t hash_seed); + Generic(torch::lazy::OpKind op, xla::Shape shape, LowerFn lower_fn, + size_t num_outputs, torch::lazy::hash_t hash_seed); NodePtr Clone(OpList operands) const override; diff --git a/torch_xla/csrc/ops/hardshrink.cpp b/torch_xla/csrc/ops/hardshrink.cpp index 06052ec7e9ef..17acd7ea3c9c 100644 --- a/torch_xla/csrc/ops/hardshrink.cpp +++ b/torch_xla/csrc/ops/hardshrink.cpp @@ -10,7 +10,7 @@ namespace ir { namespace ops { Hardshrink::Hardshrink(const Value& input, const at::Scalar& lambda) - : Node(OpKind(at::aten::hardshrink), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::hardshrink), {input}, input.shape(), /*num_outputs=*/1, ScalarHash(lambda)), lambda_(std::move(lambda)) {} diff --git a/torch_xla/csrc/ops/hardtanh_backward.cpp b/torch_xla/csrc/ops/hardtanh_backward.cpp index 642a64cb6922..b4c1d69537d3 100644 --- a/torch_xla/csrc/ops/hardtanh_backward.cpp +++ b/torch_xla/csrc/ops/hardtanh_backward.cpp @@ -11,8 +11,8 @@ namespace ops { HardtanhBackward::HardtanhBackward(const Value& grad_output, const Value& input, const at::Scalar& min_val, const at::Scalar& max_val) - : Node(OpKind(at::aten::hardtanh_backward), {grad_output, input}, - grad_output.shape(), /*num_outputs=*/1, + : Node(torch::lazy::OpKind(at::aten::hardtanh_backward), + {grad_output, input}, grad_output.shape(), /*num_outputs=*/1, torch::lazy::MHash(ScalarHash(min_val), ScalarHash(max_val))), min_val_(min_val), max_val_(max_val) {} diff --git a/torch_xla/csrc/ops/index_get.cpp b/torch_xla/csrc/ops/index_get.cpp index 531ac1735059..eb24b82891b0 100644 --- a/torch_xla/csrc/ops/index_get.cpp +++ b/torch_xla/csrc/ops/index_get.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& base, const Value& indices, IndexGet::IndexGet(const ir::Value& base, const ir::Value& indices, int64_t start_dim) - : Node(OpKind(at::aten::index), {base, indices}, + : Node(torch::lazy::OpKind(at::aten::index), {base, indices}, [&]() { return NodeOutputShape(base, indices, start_dim); }, /*num_outputs=*/1, torch::lazy::MHash(start_dim)), start_dim_(start_dim) {} diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index 0247e1cfe707..f4d4a36a1d63 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -165,7 +165,7 @@ ir::NodePtr IndexFillOp(const ir::Value& buffer, int64_t dim, }; ir::Value index_rank1 = EnsureRank1(index); return ir::ops::GenericOp( - ir::OpKind(at::aten::index_fill), {buffer, index_rank1, value}, + torch::lazy::OpKind(at::aten::index_fill), {buffer, index_rank1, value}, [&]() { return ir::ops::InferOutputShape( {buffer.shape(), index_rank1.shape(), value.shape()}, @@ -190,7 +190,7 @@ ir::NodePtr IndexAddOp(const ir::Value& buffer, int64_t dim, }; ir::Value index_rank1 = EnsureRank1(index); return ir::ops::GenericOp( - ir::OpKind(at::aten::index_add), {buffer, index_rank1, source}, + torch::lazy::OpKind(at::aten::index_add), {buffer, index_rank1, source}, [&]() { return ir::ops::InferOutputShape( {buffer.shape(), index_rank1.shape(), source.shape()}, @@ -215,7 +215,7 @@ ir::NodePtr IndexCopyOp(const ir::Value& buffer, int64_t dim, }; ir::Value index_rank1 = EnsureRank1(index); return ir::ops::GenericOp( - ir::OpKind(at::aten::index_copy), {buffer, index_rank1, source}, + torch::lazy::OpKind(at::aten::index_copy), {buffer, index_rank1, source}, [&]() { return ir::ops::InferOutputShape( {buffer.shape(), index_rank1.shape(), source.shape()}, diff --git a/torch_xla/csrc/ops/index_put.cpp b/torch_xla/csrc/ops/index_put.cpp index 82d69fd9ba71..3a42af701049 100644 --- a/torch_xla/csrc/ops/index_put.cpp +++ b/torch_xla/csrc/ops/index_put.cpp @@ -9,7 +9,8 @@ namespace ops { IndexPut::IndexPut(const ir::Value& base, const ir::Value& indices, int64_t start_dim, const ir::Value& values, bool accumulate) - : Node(OpKind(at::aten::index_put), {base, indices, values}, base.shape(), + : Node(torch::lazy::OpKind(at::aten::index_put), {base, indices, values}, + base.shape(), /*num_outputs=*/1, torch::lazy::MHash(start_dim, accumulate)), start_dim_(start_dim), accumulate_(accumulate) {} diff --git a/torch_xla/csrc/ops/index_select.cpp b/torch_xla/csrc/ops/index_select.cpp index d46621fd8a7d..9e715254d907 100644 --- a/torch_xla/csrc/ops/index_select.cpp +++ b/torch_xla/csrc/ops/index_select.cpp @@ -22,7 +22,7 @@ xla::Shape NodeOutputShape(const Value& input, const Value& index, } // namespace IndexSelect::IndexSelect(const Value& input, int64_t dim, const Value& index) - : Node(ir::OpKind(at::aten::index_select), {input, index}, + : Node(torch::lazy::OpKind(at::aten::index_select), {input, index}, [&]() { return NodeOutputShape(input, index, dim); }, /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/kth_value.cpp b/torch_xla/csrc/ops/kth_value.cpp index 27d398545dac..967993daf686 100644 --- a/torch_xla/csrc/ops/kth_value.cpp +++ b/torch_xla/csrc/ops/kth_value.cpp @@ -22,7 +22,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t k, int64_t dim, } // namespace KthValue::KthValue(const Value& input, int64_t k, int64_t dim, bool keepdim) - : Node(ir::OpKind(at::aten::kthvalue), {input}, + : Node(torch::lazy::OpKind(at::aten::kthvalue), {input}, [&]() { return NodeOutputShape(input, k, dim, keepdim); }, /*num_outputs=*/2, torch::lazy::MHash(k, dim, keepdim)), k_(k), diff --git a/torch_xla/csrc/ops/l1_loss.cpp b/torch_xla/csrc/ops/l1_loss.cpp index 4f58bd64b7e2..491d2ea4f495 100644 --- a/torch_xla/csrc/ops/l1_loss.cpp +++ b/torch_xla/csrc/ops/l1_loss.cpp @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(const Value& input, const Value& target, } // namespace L1Loss::L1Loss(const Value& input, const Value& target, ReductionMode reduction) - : Node(ir::OpKind(at::aten::l1_loss), {input, target}, + : Node(torch::lazy::OpKind(at::aten::l1_loss), {input, target}, [&]() { return NodeOutputShape(input, target, reduction); }, /*num_outputs=*/1, torch::lazy::MHash(torch::lazy::GetEnumValue(reduction))), diff --git a/torch_xla/csrc/ops/l1_loss_backward.cpp b/torch_xla/csrc/ops/l1_loss_backward.cpp index a610cf470d4c..c8060825e341 100644 --- a/torch_xla/csrc/ops/l1_loss_backward.cpp +++ b/torch_xla/csrc/ops/l1_loss_backward.cpp @@ -25,7 +25,8 @@ xla::Shape NodeOutputShape(const Value& grad_output, const Value& input, L1LossBackward::L1LossBackward(const Value& grad_output, const Value& input, const Value& target, ReductionMode reduction) - : Node(ir::OpKind(at::aten::l1_loss_backward), {grad_output, input, target}, + : Node(torch::lazy::OpKind(at::aten::l1_loss_backward), + {grad_output, input, target}, [&]() { return NodeOutputShape(grad_output, input, target, reduction); }, diff --git a/torch_xla/csrc/ops/leaky_relu.cpp b/torch_xla/csrc/ops/leaky_relu.cpp index 92eacb381213..655d96ba189e 100644 --- a/torch_xla/csrc/ops/leaky_relu.cpp +++ b/torch_xla/csrc/ops/leaky_relu.cpp @@ -8,7 +8,7 @@ namespace ir { namespace ops { LeakyRelu::LeakyRelu(const Value& input, double negative_slope) - : Node(ir::OpKind(at::aten::leaky_relu), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::leaky_relu), {input}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(negative_slope)), negative_slope_(negative_slope) {} diff --git a/torch_xla/csrc/ops/leaky_relu_backward.cpp b/torch_xla/csrc/ops/leaky_relu_backward.cpp index f30001c78338..176bb9f62d2f 100644 --- a/torch_xla/csrc/ops/leaky_relu_backward.cpp +++ b/torch_xla/csrc/ops/leaky_relu_backward.cpp @@ -9,8 +9,8 @@ namespace ops { LeakyReluBackward::LeakyReluBackward(const Value& grad_output, const Value& input, double negative_slope) - : Node(ir::OpKind(at::aten::leaky_relu_backward), {grad_output, input}, - input.shape(), + : Node(torch::lazy::OpKind(at::aten::leaky_relu_backward), + {grad_output, input}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(negative_slope)), negative_slope_(negative_slope) {} diff --git a/torch_xla/csrc/ops/linspace.cpp b/torch_xla/csrc/ops/linspace.cpp index 287723e178b3..c4b02eaada30 100644 --- a/torch_xla/csrc/ops/linspace.cpp +++ b/torch_xla/csrc/ops/linspace.cpp @@ -9,7 +9,7 @@ namespace ir { namespace ops { Linspace::Linspace(const Value& start, const Value& end, int64_t steps) - : Node(ir::OpKind(at::aten::linspace), {start, end}, + : Node(torch::lazy::OpKind(at::aten::linspace), {start, end}, [&]() { xla::PrimitiveType dtype = XlaHelpers::PromoteType( start.shape().element_type(), end.shape().element_type()); diff --git a/torch_xla/csrc/ops/log_softmax.cpp b/torch_xla/csrc/ops/log_softmax.cpp index c605884aa7db..3a5813271e5b 100644 --- a/torch_xla/csrc/ops/log_softmax.cpp +++ b/torch_xla/csrc/ops/log_softmax.cpp @@ -31,7 +31,7 @@ xla::Shape NodeOutputShape(const Value& input, LogSoftmax::LogSoftmax(const Value& input, int64_t dim, c10::optional dtype) - : Node(ir::OpKind(at::aten::log_softmax), {input}, + : Node(torch::lazy::OpKind(at::aten::log_softmax), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, torch::lazy::MHash(dim, torch::lazy::OptionalOr(dtype, -1))), diff --git a/torch_xla/csrc/ops/log_softmax_backward.cpp b/torch_xla/csrc/ops/log_softmax_backward.cpp index eaff517bf0b1..d787f1cc5f5b 100644 --- a/torch_xla/csrc/ops/log_softmax_backward.cpp +++ b/torch_xla/csrc/ops/log_softmax_backward.cpp @@ -11,7 +11,7 @@ namespace ops { LogSoftmaxBackward::LogSoftmaxBackward(const Value& grad_output, const Value& output, int64_t dim) - : Node(ir::OpKind(at::aten::_log_softmax_backward_data), + : Node(torch::lazy::OpKind(at::aten::_log_softmax_backward_data), {grad_output, output}, grad_output.shape(), /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/logsumexp.cpp b/torch_xla/csrc/ops/logsumexp.cpp index c99418d6179a..cc8f97b585a7 100644 --- a/torch_xla/csrc/ops/logsumexp.cpp +++ b/torch_xla/csrc/ops/logsumexp.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, Logsumexp::Logsumexp(const Value& input, std::vector dimensions, bool keep_reduced_dimensions) - : Node(ir::OpKind(at::aten::logsumexp), {input}, + : Node(torch::lazy::OpKind(at::aten::logsumexp), {input}, [&]() { return NodeOutputShape(input, dimensions, keep_reduced_dimensions); }, diff --git a/torch_xla/csrc/ops/masked_fill.cpp b/torch_xla/csrc/ops/masked_fill.cpp index c14275b03da8..28c4deaff3d7 100644 --- a/torch_xla/csrc/ops/masked_fill.cpp +++ b/torch_xla/csrc/ops/masked_fill.cpp @@ -11,7 +11,8 @@ namespace ops { MaskedFill::MaskedFill(const Value& input, const Value& mask, const at::Scalar& value) - : Node(OpKind(at::aten::masked_fill), {input, mask}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::masked_fill), {input, mask}, + input.shape(), /*num_outputs=*/1, ScalarHash(value)), value_(std::move(value)) {} diff --git a/torch_xla/csrc/ops/masked_scatter.cpp b/torch_xla/csrc/ops/masked_scatter.cpp index 020aa52aa277..50955996c569 100644 --- a/torch_xla/csrc/ops/masked_scatter.cpp +++ b/torch_xla/csrc/ops/masked_scatter.cpp @@ -9,7 +9,7 @@ namespace ops { MaskedScatter::MaskedScatter(const Value& input, const Value& mask, const Value& source) - : Node(ir::OpKind(at::aten::masked_scatter), {input, mask, source}, + : Node(torch::lazy::OpKind(at::aten::masked_scatter), {input, mask, source}, input.shape(), /*num_outputs=*/1) {} diff --git a/torch_xla/csrc/ops/masked_select.cpp b/torch_xla/csrc/ops/masked_select.cpp index 740e8ac8920e..936ebf553bbd 100644 --- a/torch_xla/csrc/ops/masked_select.cpp +++ b/torch_xla/csrc/ops/masked_select.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input) { } // namespace MaskedSelect::MaskedSelect(const Value& input, const Value& mask) - : Node(ir::OpKind(at::aten::masked_select), {input, mask}, + : Node(torch::lazy::OpKind(at::aten::masked_select), {input, mask}, NodeOutputShape(input), /*num_outputs=*/2) {} diff --git a/torch_xla/csrc/ops/max_in_dim.cpp b/torch_xla/csrc/ops/max_in_dim.cpp index 34900b5ccfa9..1d5cbd57b85f 100644 --- a/torch_xla/csrc/ops/max_in_dim.cpp +++ b/torch_xla/csrc/ops/max_in_dim.cpp @@ -22,7 +22,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t dim, bool keepdim) { } // namespace MaxInDim::MaxInDim(const Value& input, int64_t dim, bool keepdim) - : Node(ir::OpKind(at::aten::max), {input}, + : Node(torch::lazy::OpKind(at::aten::max), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, /*num_outputs=*/2, torch::lazy::MHash(dim, keepdim)), dim_(dim), diff --git a/torch_xla/csrc/ops/max_pool_nd.cpp b/torch_xla/csrc/ops/max_pool_nd.cpp index 130b7211ccf1..9c0a13f3fc28 100644 --- a/torch_xla/csrc/ops/max_pool_nd.cpp +++ b/torch_xla/csrc/ops/max_pool_nd.cpp @@ -43,7 +43,7 @@ MaxPoolNd::MaxPoolNd(const Value& input, int64_t spatial_dim_count, std::vector kernel_size, std::vector stride, std::vector padding, bool ceil_mode) - : Node(ir::OpKind(MaxPoolNdSymbol(spatial_dim_count)), {input}, + : Node(torch::lazy::OpKind(MaxPoolNdSymbol(spatial_dim_count)), {input}, [&]() { return NodeOutputShape(input, spatial_dim_count, kernel_size, stride, padding, ceil_mode); diff --git a/torch_xla/csrc/ops/max_pool_nd_backward.cpp b/torch_xla/csrc/ops/max_pool_nd_backward.cpp index b2fde1b14d9f..48fcf80e306c 100644 --- a/torch_xla/csrc/ops/max_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_pool_nd_backward.cpp @@ -44,7 +44,7 @@ MaxPoolNdBackward::MaxPoolNdBackward( const Value& grad_output, const Value& input, int64_t spatial_dim_count, std::vector kernel_size, std::vector stride, std::vector padding, bool ceil_mode) - : Node(ir::OpKind(MaxPoolNdBackwardSymbol(spatial_dim_count)), + : Node(torch::lazy::OpKind(MaxPoolNdBackwardSymbol(spatial_dim_count)), {grad_output, input}, [&]() { return NodeOutputShape(grad_output, input, spatial_dim_count, diff --git a/torch_xla/csrc/ops/max_unpool_nd.cpp b/torch_xla/csrc/ops/max_unpool_nd.cpp index cda03529399a..cd4a91c55fda 100644 --- a/torch_xla/csrc/ops/max_unpool_nd.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd.cpp @@ -35,7 +35,8 @@ c10::Symbol MaxUnpoolNdSymbol(int64_t spatial_dim_count) { MaxUnpoolNd::MaxUnpoolNd(const Value& input, const Value& indices, std::vector output_size) - : Node(ir::OpKind(MaxUnpoolNdSymbol(output_size.size())), {input, indices}, + : Node(torch::lazy::OpKind(MaxUnpoolNdSymbol(output_size.size())), + {input, indices}, [&]() { return NodeOutputShape(input, indices, output_size); }, /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} diff --git a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp index 7681a9522cf4..0f4eb189235b 100644 --- a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp @@ -39,7 +39,7 @@ MaxUnpoolNdBackward::MaxUnpoolNdBackward(const Value& grad_output, const Value& input, const Value& indices, std::vector output_size) - : Node(ir::OpKind(MaxUnpoolNdBackwardSymbol(output_size.size())), + : Node(torch::lazy::OpKind(MaxUnpoolNdBackwardSymbol(output_size.size())), {grad_output, input, indices}, [&]() { return NodeOutputShape(grad_output, input, indices, output_size); diff --git a/torch_xla/csrc/ops/mean.cpp b/torch_xla/csrc/ops/mean.cpp index 0716977e5ef8..7930625b2a81 100644 --- a/torch_xla/csrc/ops/mean.cpp +++ b/torch_xla/csrc/ops/mean.cpp @@ -38,7 +38,7 @@ xla::Shape NodeOutputShape(const Value& input, Mean::Mean(const Value& input, std::vector dimensions, bool keep_reduced_dimensions, c10::optional dtype) - : Node(ir::OpKind(at::aten::mean), {input}, + : Node(torch::lazy::OpKind(at::aten::mean), {input}, [&]() { return NodeOutputShape(input, dimensions, keep_reduced_dimensions, dtype); diff --git a/torch_xla/csrc/ops/min_in_dim.cpp b/torch_xla/csrc/ops/min_in_dim.cpp index a3c393c4027a..463f7e9388d2 100644 --- a/torch_xla/csrc/ops/min_in_dim.cpp +++ b/torch_xla/csrc/ops/min_in_dim.cpp @@ -22,7 +22,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t dim, bool keepdim) { } // namespace MinInDim::MinInDim(const Value& input, int64_t dim, bool keepdim) - : Node(ir::OpKind(at::aten::min), {input}, + : Node(torch::lazy::OpKind(at::aten::min), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, /*num_outputs=*/2, torch::lazy::MHash(dim, keepdim)), dim_(dim), diff --git a/torch_xla/csrc/ops/mse_loss.cpp b/torch_xla/csrc/ops/mse_loss.cpp index 04276b5d5fb7..80d0811396b7 100644 --- a/torch_xla/csrc/ops/mse_loss.cpp +++ b/torch_xla/csrc/ops/mse_loss.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, const Value& target, MseLoss::MseLoss(const Value& input, const Value& target, ReductionMode reduction) - : Node(ir::OpKind(at::aten::mse_loss), {input, target}, + : Node(torch::lazy::OpKind(at::aten::mse_loss), {input, target}, [&]() { return NodeOutputShape(input, target, reduction); }, /*num_outputs=*/1, torch::lazy::MHash(torch::lazy::GetEnumValue(reduction))), diff --git a/torch_xla/csrc/ops/mse_loss_backward.cpp b/torch_xla/csrc/ops/mse_loss_backward.cpp index 94130e7953d3..0c78c0e865f0 100644 --- a/torch_xla/csrc/ops/mse_loss_backward.cpp +++ b/torch_xla/csrc/ops/mse_loss_backward.cpp @@ -27,7 +27,7 @@ xla::Shape NodeOutputShape(const Value& grad_output, const Value& input, MseLossBackward::MseLossBackward(const Value& grad_output, const Value& input, const Value& target, ReductionMode reduction) - : Node(ir::OpKind(at::aten::mse_loss_backward), + : Node(torch::lazy::OpKind(at::aten::mse_loss_backward), {grad_output, input, target}, [&]() { return NodeOutputShape(grad_output, input, target, reduction); diff --git a/torch_xla/csrc/ops/native_batch_norm_backward.cpp b/torch_xla/csrc/ops/native_batch_norm_backward.cpp index 5d67164c192d..9dfe4affeb76 100644 --- a/torch_xla/csrc/ops/native_batch_norm_backward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_backward.cpp @@ -32,7 +32,7 @@ xla::Shape NodeOutputShape(const Value& grad_out, const Value& input, NativeBatchNormBackward::NativeBatchNormBackward( const Value& grad_out, const Value& input, const Value& weight, const Value& save_mean, const Value& save_invstd, bool training, double eps) - : Node(ir::OpKind(at::aten::native_batch_norm_backward), + : Node(torch::lazy::OpKind(at::aten::native_batch_norm_backward), {grad_out, input, weight, save_mean, save_invstd}, [&]() { return NodeOutputShape(grad_out, input, weight, save_mean, diff --git a/torch_xla/csrc/ops/native_batch_norm_forward.cpp b/torch_xla/csrc/ops/native_batch_norm_forward.cpp index 7d695be2e232..a7f899bd75e9 100644 --- a/torch_xla/csrc/ops/native_batch_norm_forward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_forward.cpp @@ -56,7 +56,7 @@ NativeBatchNormForward::NativeBatchNormForward(const Value& input, const Value& running_mean, const Value& running_var, bool training, double eps) - : Node(ir::OpKind(at::aten::native_batch_norm), + : Node(torch::lazy::OpKind(at::aten::native_batch_norm), {input, weight, bias, running_mean, running_var}, [&]() { return NodeOutputShape(input, weight, bias, running_mean, diff --git a/torch_xla/csrc/ops/nll_loss.cpp b/torch_xla/csrc/ops/nll_loss.cpp index a52da291c0bd..298ddf355164 100644 --- a/torch_xla/csrc/ops/nll_loss.cpp +++ b/torch_xla/csrc/ops/nll_loss.cpp @@ -37,7 +37,7 @@ xla::Shape NodeOutputShape(const Value& logits, const Value& labels, NllLoss::NllLoss(const Value& logits, const Value& labels, const absl::optional& weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss), + : Node(torch::lazy::OpKind(at::aten::nll_loss), xla::util::GetValuesVector({logits, labels}, {&weight}), [&]() { return NodeOutputShape(logits, labels, weight, reduction, diff --git a/torch_xla/csrc/ops/nll_loss2d.cpp b/torch_xla/csrc/ops/nll_loss2d.cpp index 657942b3eee4..2655c2437941 100644 --- a/torch_xla/csrc/ops/nll_loss2d.cpp +++ b/torch_xla/csrc/ops/nll_loss2d.cpp @@ -37,7 +37,7 @@ xla::Shape NodeOutputShape(const Value& logits, const Value& labels, NllLoss2d::NllLoss2d(const Value& logits, const Value& labels, const absl::optional& weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss2d), + : Node(torch::lazy::OpKind(at::aten::nll_loss2d), xla::util::GetValuesVector({logits, labels}, {&weight}), [&]() { return NodeOutputShape(logits, labels, weight, reduction, diff --git a/torch_xla/csrc/ops/nll_loss2d_backward.cpp b/torch_xla/csrc/ops/nll_loss2d_backward.cpp index 26e62000de36..573acc4887e0 100644 --- a/torch_xla/csrc/ops/nll_loss2d_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss2d_backward.cpp @@ -45,7 +45,7 @@ NllLoss2dBackward::NllLoss2dBackward(const Value& grad_output, const absl::optional& weight, const absl::optional& total_weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss2d_backward), + : Node(torch::lazy::OpKind(at::aten::nll_loss2d_backward), xla::util::GetValuesVector({grad_output, logits, labels}, {&weight, &total_weight}), [&]() { diff --git a/torch_xla/csrc/ops/nll_loss_backward.cpp b/torch_xla/csrc/ops/nll_loss_backward.cpp index 88cabb7982b5..37d2b4162784 100644 --- a/torch_xla/csrc/ops/nll_loss_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss_backward.cpp @@ -45,7 +45,7 @@ NllLossBackward::NllLossBackward(const Value& grad_output, const Value& logits, const absl::optional& weight, const absl::optional& total_weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss_backward), + : Node(torch::lazy::OpKind(at::aten::nll_loss_backward), xla::util::GetValuesVector({grad_output, logits, labels}, {&weight, &total_weight}), [&]() { diff --git a/torch_xla/csrc/ops/nonzero.cpp b/torch_xla/csrc/ops/nonzero.cpp index db7a6e523c70..81032f4a6666 100644 --- a/torch_xla/csrc/ops/nonzero.cpp +++ b/torch_xla/csrc/ops/nonzero.cpp @@ -24,7 +24,8 @@ xla::Shape NodeOutputShape(const Value& input) { } // namespace NonZero::NonZero(const Value& input) - : Node(ir::OpKind(at::aten::nonzero), {input}, NodeOutputShape(input), + : Node(torch::lazy::OpKind(at::aten::nonzero), {input}, + NodeOutputShape(input), /*num_outputs=*/2) {} NodePtr NonZero::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/normal.cpp b/torch_xla/csrc/ops/normal.cpp index 925613a57e99..20e2f1938385 100644 --- a/torch_xla/csrc/ops/normal.cpp +++ b/torch_xla/csrc/ops/normal.cpp @@ -9,7 +9,8 @@ namespace ir { namespace ops { Normal::Normal(const Value& mean, const Value& std, const Value& seed) - : Node(ir::OpKind(at::aten::normal), {mean, std, seed}, mean.shape()) {} + : Node(torch::lazy::OpKind(at::aten::normal), {mean, std, seed}, + mean.shape()) {} NodePtr Normal::Clone(OpList operands) const { return MakeNode(operands.at(0), operands.at(1), operands.at(2)); diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index f53547389e91..c507f9bb1a84 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -34,15 +34,15 @@ namespace torch_xla { namespace ir { namespace ops { -#define PTXLA_UNARY_OP(name, sym, xla_fn) \ - NodePtr name(const Value& input) { \ - auto lower_fn = [](const Node& node, \ - LoweringContext* loctx) -> XlaOpVector { \ - xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); \ - return node.ReturnOp(xla_fn(xla_input), loctx); \ - }; \ - return GenericOp(OpKind(sym), {input}, input.shape(), \ - std::move(lower_fn)); \ +#define PTXLA_UNARY_OP(name, sym, xla_fn) \ + NodePtr name(const Value& input) { \ + auto lower_fn = [](const Node& node, \ + LoweringContext* loctx) -> XlaOpVector { \ + xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); \ + return node.ReturnOp(xla_fn(xla_input), loctx); \ + }; \ + return GenericOp(torch::lazy::OpKind(sym), {input}, input.shape(), \ + std::move(lower_fn)); \ } #define PTXLA_BINARY_OP(name, sym, xla_fn) \ @@ -59,7 +59,7 @@ namespace ops { return node.ReturnOp(xla_fn(promoted.first, promoted.second), loctx); \ }; \ return GenericOp( \ - OpKind(sym), {input0, input1}, \ + torch::lazy::OpKind(sym), {input0, input1}, \ [&]() { \ return InferOutputShape({input0.shape(), input1.shape()}, shape_fn); \ }, \ @@ -104,7 +104,7 @@ NodePtr Trunc(const Value& input) { return Floor(Abs(input)) * SignOp(input); } NodePtr FracOp(const Value& input) { return input - Trunc(input); } -NodePtr LogBase(const Value& input, OpKind op, double base) { +NodePtr LogBase(const Value& input, torch::lazy::OpKind op, double base) { auto lower_fn = [base](const Node& node, LoweringContext* loctx) -> XlaOpVector { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); @@ -122,8 +122,8 @@ NodePtr ReciprocalOp(const Value& input) { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); return node.ReturnOp(BuildReciprocal(xla_input), loctx); }; - return GenericOp(OpKind(at::aten::reciprocal), {input}, input.shape(), - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::reciprocal), {input}, + input.shape(), std::move(lower_fn)); } NodePtr SgnOp(const Value& input) { @@ -131,7 +131,7 @@ NodePtr SgnOp(const Value& input) { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); return node.ReturnOp(BuildSgn(xla_input), loctx); }; - return GenericOp(OpKind(at::aten::sgn), {input}, input.shape(), + return GenericOp(torch::lazy::OpKind(at::aten::sgn), {input}, input.shape(), std::move(lower_fn)); } @@ -140,7 +140,7 @@ NodePtr SignOp(const Value& input) { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); return node.ReturnOp(BuildSign(xla_input), loctx); }; - return GenericOp(OpKind(at::aten::sign), {input}, input.shape(), + return GenericOp(torch::lazy::OpKind(at::aten::sign), {input}, input.shape(), std::move(lower_fn)); } @@ -149,7 +149,7 @@ NodePtr Abs(const Value& input) { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); return node.ReturnOp(BuildAbs(xla_input), loctx); }; - return GenericOp(OpKind(at::aten::abs), {input}, input.shape(), + return GenericOp(torch::lazy::OpKind(at::aten::abs), {input}, input.shape(), std::move(lower_fn)); } @@ -165,7 +165,7 @@ NodePtr ReluOp(const Value& input) { return BuildRelu(operands[0]); }; return GenericOp( - OpKind(at::aten::relu), {input}, + torch::lazy::OpKind(at::aten::relu), {input}, [&]() { return InferOutputShape({input.shape()}, lower_for_shape_fn); }, std::move(lower_fn)); } @@ -178,8 +178,8 @@ NodePtr Prelu(const Value& input, const Value& weight) { return node.ReturnOp(xla_output, loctx); }; - return GenericOp(OpKind(at::aten::prelu), {input, weight}, input.shape(), - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::prelu), {input, weight}, + input.shape(), std::move(lower_fn)); } NodePtr HardSigmoid(const Value& input) { @@ -187,8 +187,8 @@ NodePtr HardSigmoid(const Value& input) { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); return node.ReturnOp(BuildHardSigmoid(xla_input), loctx); }; - return GenericOp(OpKind(at::aten::hardsigmoid), {input}, input.shape(), - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::hardsigmoid), {input}, + input.shape(), std::move(lower_fn)); } NodePtr HardSigmoidBackward(const Value& grad_output, const Value& input) { @@ -198,8 +198,8 @@ NodePtr HardSigmoidBackward(const Value& grad_output, const Value& input) { return node.ReturnOp(BuildHardSigmoidBackward(xla_grad_output, xla_input), loctx); }; - return GenericOp(OpKind(at::aten::hardsigmoid_backward), {grad_output, input}, - input.shape(), std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::hardsigmoid_backward), + {grad_output, input}, input.shape(), std::move(lower_fn)); } std::tuple LogSigmoid(const Value& input) { @@ -229,7 +229,7 @@ NodePtr SiLU(const Value& input) { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); return node.ReturnOp(xla_input * BuildSigmoid(xla_input), loctx); }; - return GenericOp(OpKind(at::aten::silu), {input}, input.shape(), + return GenericOp(torch::lazy::OpKind(at::aten::silu), {input}, input.shape(), std::move(lower_fn)); } @@ -243,13 +243,13 @@ NodePtr SiLUBackward(const Value& grad_output, const Value& input) { [](absl::Span operands) -> xla::XlaOp { return BuildSiLUBackward(operands[0], operands[1]); }; - return GenericOp(OpKind(at::aten::silu_backward), {grad_output, input}, - [&]() { - return InferOutputShape( - {grad_output.shape(), input.shape()}, - lower_for_shape_fn); - }, - std::move(lower_fn)); + return GenericOp( + torch::lazy::OpKind(at::aten::silu_backward), {grad_output, input}, + [&]() { + return InferOutputShape({grad_output.shape(), input.shape()}, + lower_for_shape_fn); + }, + std::move(lower_fn)); } NodePtr Sigmoid(const Value& input) { @@ -257,8 +257,8 @@ NodePtr Sigmoid(const Value& input) { xla::XlaOp xla_input = loctx->GetOutputOp(node.operand(0)); return node.ReturnOp(BuildSigmoid(xla_input), loctx); }; - return GenericOp(OpKind(at::aten::sigmoid), {input}, input.shape(), - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::sigmoid), {input}, + input.shape(), std::move(lower_fn)); } NodePtr SigmoidBackward(const Value& grad_output, const Value& output) { @@ -291,8 +291,8 @@ NodePtr Clamp(const Value& input, const Value& min, const Value& max) { /*device=*/nullptr); return node.ReturnOp(xla::Clamp(xla_min, xla_input, xla_max), loctx); }; - return GenericOp(OpKind(at::aten::clamp), {input, min, max}, input.shape(), - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::clamp), {input, min, max}, + input.shape(), std::move(lower_fn)); } NodePtr Ger(const Value& input, const Value& other) { @@ -305,7 +305,7 @@ NodePtr Ger(const Value& input, const Value& other) { [](absl::Span operands) -> xla::XlaOp { return BuildGer(operands[0], operands[1]); }; - return GenericOp(OpKind(at::aten::ger), {input, other}, + return GenericOp(torch::lazy::OpKind(at::aten::ger), {input, other}, [&]() { return InferOutputShape({input.shape(), other.shape()}, lower_for_shape_fn); @@ -326,7 +326,7 @@ NodePtr AddMatMulOp(const Value& input, const Value& weight, [](absl::Span operands) -> xla::XlaOp { return BuildMatMul(operands[0], operands[1], operands[2]); }; - return GenericOp(OpKind(at::aten::addmm), {input, weight, bias}, + return GenericOp(torch::lazy::OpKind(at::aten::addmm), {input, weight, bias}, [&]() { return InferOutputShape( {input.shape(), weight.shape(), bias.shape()}, @@ -345,7 +345,7 @@ NodePtr Dot(const Value& input, const Value& weight) { [](absl::Span operands) -> xla::XlaOp { return BuildDot(operands[0], operands[1]); }; - return GenericOp(OpKind(at::aten::mm), {input, weight}, + return GenericOp(torch::lazy::OpKind(at::aten::mm), {input, weight}, [&]() { return InferOutputShape({input.shape(), weight.shape()}, lower_for_shape_fn); @@ -366,7 +366,7 @@ NodePtr MatMul(const Value& lhs, const Value& rhs) { return CreateMatMul(operands[0], operands[1]); }; return GenericOp( - OpKind(at::aten::matmul), {lhs, rhs}, + torch::lazy::OpKind(at::aten::matmul), {lhs, rhs}, [&]() { return InferOutputShape({lhs.shape(), rhs.shape()}, lower_for_shape_fn); }, @@ -389,13 +389,14 @@ NodePtr AdaptiveMaxPool2dBackward(const Value& grad_output, /*input=*/operands[1], /*pool_dim=*/2); }; - return GenericOp( - OpKind(at::aten::adaptive_max_pool2d_backward), {grad_output, input}, - [&]() { - return InferOutputShape({grad_output.shape(), input.shape()}, - lower_for_shape_fn); - }, - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::adaptive_max_pool2d_backward), + {grad_output, input}, + [&]() { + return InferOutputShape( + {grad_output.shape(), input.shape()}, + lower_for_shape_fn); + }, + std::move(lower_fn)); } NodePtr AdaptiveAvgPool3dBackward(const Value& grad_output, @@ -413,13 +414,14 @@ NodePtr AdaptiveAvgPool3dBackward(const Value& grad_output, return BuildAdaptiveAvgPool3dBackward(/*out_backprop=*/operands[0], /*input=*/operands[1]); }; - return GenericOp( - OpKind(at::aten::adaptive_avg_pool3d_backward), {grad_output, input}, - [&]() { - return InferOutputShape({grad_output.shape(), input.shape()}, - lower_for_shape_fn); - }, - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::adaptive_avg_pool3d_backward), + {grad_output, input}, + [&]() { + return InferOutputShape( + {grad_output.shape(), input.shape()}, + lower_for_shape_fn); + }, + std::move(lower_fn)); } NodePtr AdaptiveAvgPool2dBackward(const Value& grad_output, @@ -437,13 +439,14 @@ NodePtr AdaptiveAvgPool2dBackward(const Value& grad_output, return BuildAdaptiveAvgPool2dBackward(/*out_backprop=*/operands[0], /*input=*/operands[1]); }; - return GenericOp( - OpKind(at::aten::adaptive_avg_pool2d_backward), {grad_output, input}, - [&]() { - return InferOutputShape({grad_output.shape(), input.shape()}, - lower_for_shape_fn); - }, - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::adaptive_avg_pool2d_backward), + {grad_output, input}, + [&]() { + return InferOutputShape( + {grad_output.shape(), input.shape()}, + lower_for_shape_fn); + }, + std::move(lower_fn)); } NodePtr ComparisonOp(c10::Symbol kind, const Value& input, const Value& other) { @@ -458,7 +461,7 @@ NodePtr ComparisonOp(c10::Symbol kind, const Value& input, const Value& other) { [kind](absl::Span operands) -> xla::XlaOp { return BuildComparisonOp(kind, operands[0], operands[1]); }; - return GenericOp(OpKind(kind), {input, other}, + return GenericOp(torch::lazy::OpKind(kind), {input, other}, [&]() { return InferOutputShape({input.shape(), other.shape()}, lower_for_shape_fn); @@ -479,8 +482,9 @@ NodePtr Where(const Value& condition, const Value& input, const Value& other) { promoted_branches.second), loctx); }; - return GenericOp(OpKind(at::aten::where), {condition, input, other}, - input.shape(), std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::where), + {condition, input, other}, input.shape(), + std::move(lower_fn)); } NodePtr ARange(const at::Scalar& start, const at::Scalar& end, @@ -570,7 +574,7 @@ NodePtr BroadcastTensors(absl::Span tensors) { return xla::Tuple(results.front().builder(), results); }; return GenericOp( - OpKind(at::aten::broadcast_tensors), tensors, + torch::lazy::OpKind(at::aten::broadcast_tensors), tensors, [&]() { return InferOutputShape(tensor_shapes, lower_for_shape_fn); }, std::move(lower_fn), /*num_outputs=*/tensors.size()); } @@ -618,7 +622,7 @@ NodePtr Identity(int64_t lines, int64_t cols, xla::PrimitiveType element_type) { xla::IdentityMatrix(loctx->builder(), element_type, lines, cols), loctx); }; - return GenericOp(OpKind(at::aten::eye), + return GenericOp(torch::lazy::OpKind(at::aten::eye), xla::ShapeUtil::MakeShape(element_type, {lines, cols}), std::move(lower_fn), /*num_outputs=*/1, torch::lazy::MHash(lines, cols)); @@ -710,7 +714,7 @@ NodePtr MaxUnary(const Value& input) { return node.ReturnOp(xla::Reshape(result, {}), loctx); }; XLA_CHECK_GT(xla::ShapeUtil::ElementsIn(input.shape()), 0); - return GenericOp(OpKind(at::aten::max), {input}, + return GenericOp(torch::lazy::OpKind(at::aten::max), {input}, xla::ShapeUtil::MakeShape(input.shape().element_type(), {}), std::move(lower_fn)); } @@ -729,7 +733,7 @@ NodePtr MinUnary(const Value& input) { return node.ReturnOp(xla::Reshape(result, {}), loctx); }; XLA_CHECK_GT(xla::ShapeUtil::ElementsIn(input.shape()), 0); - return GenericOp(OpKind(at::aten::min), {input}, + return GenericOp(torch::lazy::OpKind(at::aten::min), {input}, xla::ShapeUtil::MakeShape(input.shape().element_type(), {}), std::move(lower_fn)); } @@ -743,7 +747,7 @@ NodePtr Take(const Value& input, const Value& index) { }; xla::Shape result_shape = index.shape(); result_shape.set_element_type(input.shape().element_type()); - return GenericOp(OpKind(at::aten::take), {input, index}, + return GenericOp(torch::lazy::OpKind(at::aten::take), {input, index}, std::move(result_shape), std::move(lower_fn)); } @@ -801,7 +805,7 @@ NodePtr LogDet(const Value& input) { xla::Shape logdet_shape(input_shape); logdet_shape.DeleteDimension(input_shape.rank() - 1); logdet_shape.DeleteDimension(input_shape.rank() - 2); - return GenericOp(OpKind(at::aten::logdet), {input}, logdet_shape, + return GenericOp(torch::lazy::OpKind(at::aten::logdet), {input}, logdet_shape, std::move(lower_fn)); } @@ -811,8 +815,8 @@ NodePtr Inverse(const Value& input) { xla::XlaOp result = BuildInverse(xla_input); return node.ReturnOp(result, loctx); }; - return GenericOp(OpKind(at::aten::inverse), {input}, input.shape(), - std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::inverse), {input}, + input.shape(), std::move(lower_fn)); } NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias, @@ -835,7 +839,7 @@ NodePtr BaddBmm(const Value& lhs, const Value& rhs, const Value& bias, return BuildMatMulWithMultiplier(operands[0], operands[1], operands[2], operands[3], operands[4]); }; - return GenericOp(OpKind(at::aten::baddbmm), + return GenericOp(torch::lazy::OpKind(at::aten::baddbmm), {lhs, rhs, bias, product_multiplier, bias_multiplier}, [&]() { return InferOutputShape( @@ -863,7 +867,7 @@ NodePtr LogicalNot(const Value& input) { operands[0], [](xla::XlaOp lhs) { return xla::Not(lhs); }); }; return GenericOp( - OpKind(at::aten::logical_not), {input}, + torch::lazy::OpKind(at::aten::logical_not), {input}, [&]() { return InferOutputShape({input.shape()}, shape_fn); }, std::move(lower_fn)); } @@ -884,7 +888,7 @@ NodePtr LogicalXor(const Value& input, const Value& other) { [](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Xor(lhs, rhs); }); }; return GenericOp( - OpKind(at::aten::logical_xor), {input, other}, + torch::lazy::OpKind(at::aten::logical_xor), {input, other}, [&]() { return InferOutputShape({input.shape(), other.shape()}, shape_fn); }, @@ -907,7 +911,7 @@ NodePtr LogicalAnd(const Value& input, const Value& other) { [](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::And(lhs, rhs); }); }; return GenericOp( - OpKind(at::aten::logical_and), {input, other}, + torch::lazy::OpKind(at::aten::logical_and), {input, other}, [&]() { return InferOutputShape({input.shape(), other.shape()}, shape_fn); }, @@ -930,7 +934,7 @@ NodePtr LogicalOr(const Value& input, const Value& other) { [](xla::XlaOp lhs, xla::XlaOp rhs) { return xla::Or(lhs, rhs); }); }; return GenericOp( - OpKind(at::aten::logical_or), {input, other}, + torch::lazy::OpKind(at::aten::logical_or), {input, other}, [&]() { return InferOutputShape({input.shape(), other.shape()}, shape_fn); }, @@ -949,7 +953,7 @@ NodePtr XLogY(const Value& input, const Value& other) { XLA_CHECK_EQ(operands.size(), 2) << "Unexpected number of operands"; return BuildXLogY(operands[0], operands[1]); }; - return GenericOp(OpKind(at::aten::xlogy), {input, other}, + return GenericOp(torch::lazy::OpKind(at::aten::xlogy), {input, other}, [&]() { return InferOutputShape({input.shape(), other.shape()}, lower_for_shape_fn); @@ -971,8 +975,9 @@ NodePtr NanToNum(const Value& input, const Value& nan, const Value& posinf, neginf_replacement, xla_input))); return node.ReturnOp(result, loctx); }; - return GenericOp(OpKind(at::aten::nan_to_num), {input, nan, posinf, neginf}, - input.shape(), std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::nan_to_num), + {input, nan, posinf, neginf}, input.shape(), + std::move(lower_fn)); } NodePtr SLogDet(const Value& input) { @@ -989,7 +994,7 @@ NodePtr SLogDet(const Value& input) { }; return GenericOp( - OpKind(at::aten::slogdet), {input}, + torch::lazy::OpKind(at::aten::slogdet), {input}, [&]() { return InferOutputShape({input.shape()}, lower_for_shape_fn); }, std::move(lower_fn), /*num_outputs=*/2); } @@ -1004,8 +1009,9 @@ NodePtr Softplus(const Value& input, const Value& beta, return node.ReturnOp(xla_output, loctx); }; - return GenericOp(OpKind(at::aten::softplus), {input, beta, threshold}, - input.shape(), std::move(lower_fn)); + return GenericOp(torch::lazy::OpKind(at::aten::softplus), + {input, beta, threshold}, input.shape(), + std::move(lower_fn)); } } // namespace ops diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 80551e4627c6..681a34a7636c 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -26,7 +26,7 @@ inline NodePtr ConstantOp(xla::Literal value) { } inline NodePtr GenericOp( - OpKind op, absl::Span operands, xla::Shape shape, + torch::lazy::OpKind op, absl::Span operands, xla::Shape shape, Generic::LowerFn lower_fn, size_t num_outputs = 1, // cast to uint32_t to avoid ambiguous constructor of uint128 torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { @@ -36,7 +36,7 @@ inline NodePtr GenericOp( } inline NodePtr GenericOp( - OpKind op, absl::Span operands, + torch::lazy::OpKind op, absl::Span operands, const std::function& shape_fn, Generic::LowerFn lower_fn, size_t num_outputs = 1, // cast to uint32_t to avoid ambiguous constructor of uint128 @@ -46,8 +46,9 @@ inline NodePtr GenericOp( hash_seed); } -inline NodePtr GenericOp(OpKind op, xla::Shape shape, Generic::LowerFn lower_fn, - size_t num_outputs, torch::lazy::hash_t hash_seed) { +inline NodePtr GenericOp(torch::lazy::OpKind op, xla::Shape shape, + Generic::LowerFn lower_fn, size_t num_outputs, + torch::lazy::hash_t hash_seed) { return torch_xla::ir::MakeNode(std::move(op), std::move(shape), std::move(lower_fn), num_outputs, hash_seed); @@ -105,7 +106,7 @@ NodePtr Erfinv(const Value& input); NodePtr Log(const Value& input); -NodePtr LogBase(const Value& input, OpKind op, double base); +NodePtr LogBase(const Value& input, torch::lazy::OpKind op, double base); NodePtr Log1p(const Value& input); diff --git a/torch_xla/csrc/ops/permute.cpp b/torch_xla/csrc/ops/permute.cpp index 9e35b3598d60..21ba22dede13 100644 --- a/torch_xla/csrc/ops/permute.cpp +++ b/torch_xla/csrc/ops/permute.cpp @@ -22,7 +22,7 @@ xla::Shape NodeOutputShape(const Value& input, absl::Span dims) { } // namespace Permute::Permute(const Value& input, std::vector dims) - : Node(ir::OpKind(at::aten::permute), {input}, + : Node(torch::lazy::OpKind(at::aten::permute), {input}, [&]() { return NodeOutputShape(input, dims); }, /*num_outputs=*/1, torch::lazy::MHash(dims)), dims_(std::move(dims)) {} diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index d898d8dc0a06..48aac632c1df 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -43,7 +43,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, Prod::Prod(const Value& input, std::vector dimensions, bool keep_reduced_dimensions, c10::optional dtype) - : Node(ir::OpKind(at::aten::prod), {input}, + : Node(torch::lazy::OpKind(at::aten::prod), {input}, [&]() { return NodeOutputShape(input, dimensions, keep_reduced_dimensions, dtype); diff --git a/torch_xla/csrc/ops/put.cpp b/torch_xla/csrc/ops/put.cpp index 92655c66f0a1..5abf8a8d14dd 100644 --- a/torch_xla/csrc/ops/put.cpp +++ b/torch_xla/csrc/ops/put.cpp @@ -9,7 +9,8 @@ namespace ops { Put::Put(const Value& input, const Value& index, const Value& source, bool accumulate) - : Node(ir::OpKind(at::aten::put), {input, index, source}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::put), {input, index, source}, + input.shape(), /*num_outputs=*/1, torch::lazy::MHash(accumulate)), accumulate_(accumulate) {} diff --git a/torch_xla/csrc/ops/qr.cpp b/torch_xla/csrc/ops/qr.cpp index 1a95ef8df2b8..35ba2754cb18 100644 --- a/torch_xla/csrc/ops/qr.cpp +++ b/torch_xla/csrc/ops/qr.cpp @@ -42,7 +42,7 @@ xla::Shape NodeOutputShape(const Value& input, bool some) { } // namespace QR::QR(const Value& input, bool some) - : Node(ir::OpKind(at::aten::qr), {input}, + : Node(torch::lazy::OpKind(at::aten::qr), {input}, [&]() { return NodeOutputShape(input, some); }, /*num_outputs=*/2, torch::lazy::MHash(some)), some_(some) {} diff --git a/torch_xla/csrc/ops/reflection_pad2d.cpp b/torch_xla/csrc/ops/reflection_pad2d.cpp index 3db5b9eaff40..5bf0b4646ec8 100644 --- a/torch_xla/csrc/ops/reflection_pad2d.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d.cpp @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(const Value& input, ReflectionPad2d::ReflectionPad2d(const Value& input, std::vector padding) - : Node(OpKind(at::aten::reflection_pad2d), {input}, + : Node(torch::lazy::OpKind(at::aten::reflection_pad2d), {input}, [&]() { return NodeOutputShape(input, padding); }, /*num_outputs=*/1, torch::lazy::MHash(padding)), padding_(std::move(padding)) {} diff --git a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp index cd4f1e2438f3..6debc866f48c 100644 --- a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp @@ -25,7 +25,8 @@ xla::Shape NodeOutputShape(const Value& grad_output, const Value& input, ReflectionPad2dBackward::ReflectionPad2dBackward(const Value& grad_output, const Value& input, std::vector padding) - : Node(OpKind(at::aten::reflection_pad2d_backward), {grad_output, input}, + : Node(torch::lazy::OpKind(at::aten::reflection_pad2d_backward), + {grad_output, input}, [&]() { return NodeOutputShape(grad_output, input, padding); }, /*num_outputs=*/1, torch::lazy::MHash(padding)), padding_(std::move(padding)) {} diff --git a/torch_xla/csrc/ops/repeat.cpp b/torch_xla/csrc/ops/repeat.cpp index b59b89005a6e..6a546b66d296 100644 --- a/torch_xla/csrc/ops/repeat.cpp +++ b/torch_xla/csrc/ops/repeat.cpp @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(const Value& input, } // namespace Repeat::Repeat(const Value& input, std::vector repeats) - : Node(ir::OpKind(at::aten::repeat), {input}, + : Node(torch::lazy::OpKind(at::aten::repeat), {input}, [&]() { return NodeOutputShape(input, repeats); }, /*num_outputs=*/1, torch::lazy::MHash(repeats)), repeats_(std::move(repeats)) {} diff --git a/torch_xla/csrc/ops/resize.cpp b/torch_xla/csrc/ops/resize.cpp index 360c1e776f95..10368a400393 100644 --- a/torch_xla/csrc/ops/resize.cpp +++ b/torch_xla/csrc/ops/resize.cpp @@ -17,7 +17,7 @@ xla::Shape NodeOutputShape(const Value& input, absl::Span size) { } // namespace Resize::Resize(const Value& input, std::vector size) - : Node(ir::OpKind(at::aten::resize), {input}, + : Node(torch::lazy::OpKind(at::aten::resize), {input}, [&]() { return NodeOutputShape(input, size); }, /*num_outputs=*/1, torch::lazy::MHash(size)), size_(std::move(size)) {} diff --git a/torch_xla/csrc/ops/rrelu_with_noise.cpp b/torch_xla/csrc/ops/rrelu_with_noise.cpp index 1b01b227e5c7..6cec11a964ef 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise.cpp @@ -12,7 +12,7 @@ namespace ops { RreluWithNoise::RreluWithNoise(const Value& input, const Value& seed, const at::Scalar& lower, const at::Scalar& upper, bool training) - : Node(ir::OpKind(at::aten::rrelu_with_noise), {input, seed}, + : Node(torch::lazy::OpKind(at::aten::rrelu_with_noise), {input, seed}, xla::ShapeUtil::MakeTupleShape({input.shape(), input.shape()}), /*num_outputs=*/2, torch::lazy::MHash(ScalarHash(lower), ScalarHash(upper), training)), diff --git a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp index 99f4c232e063..a4f0cedd1992 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp @@ -11,7 +11,7 @@ namespace ops { RreluWithNoiseBackward::RreluWithNoiseBackward( const Value& grad_output, const Value& input, const Value& noise, const at::Scalar& lower, const at::Scalar& upper, bool training) - : Node(ir::OpKind(at::aten::rrelu_with_noise_backward), + : Node(torch::lazy::OpKind(at::aten::rrelu_with_noise_backward), {grad_output, input, noise}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(ScalarHash(lower), ScalarHash(upper), training)), diff --git a/torch_xla/csrc/ops/scalar.cpp b/torch_xla/csrc/ops/scalar.cpp index ec3d6cfcc975..96c9d86e27ee 100644 --- a/torch_xla/csrc/ops/scalar.cpp +++ b/torch_xla/csrc/ops/scalar.cpp @@ -13,12 +13,13 @@ namespace ir { namespace ops { Scalar::Scalar(const at::Scalar& value, xla::Shape shape) - : Node(OpKind(at::prim::Constant), std::move(shape), /*num_outputs=*/1, - ScalarHash(value)), + : Node(torch::lazy::OpKind(at::prim::Constant), std::move(shape), + /*num_outputs=*/1, ScalarHash(value)), value_(std::move(value)) {} Scalar::Scalar(const at::Scalar& value, xla::PrimitiveType type) - : Node(OpKind(at::prim::Constant), xla::ShapeUtil::MakeShape(type, {}), + : Node(torch::lazy::OpKind(at::prim::Constant), + xla::ShapeUtil::MakeShape(type, {}), /*num_outputs=*/1, ScalarHash(value)), value_(std::move(value)) {} diff --git a/torch_xla/csrc/ops/scatter.cpp b/torch_xla/csrc/ops/scatter.cpp index 66e1c12d10ce..a356083d0631 100644 --- a/torch_xla/csrc/ops/scatter.cpp +++ b/torch_xla/csrc/ops/scatter.cpp @@ -9,7 +9,8 @@ namespace ops { Scatter::Scatter(const Value& input, const Value& index, const Value& src, int64_t dim) - : Node(ir::OpKind(at::aten::scatter), {input, index, src}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::scatter), {input, index, src}, + input.shape(), /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/scatter_add.cpp b/torch_xla/csrc/ops/scatter_add.cpp index d51cb3d18a1c..de764178701c 100644 --- a/torch_xla/csrc/ops/scatter_add.cpp +++ b/torch_xla/csrc/ops/scatter_add.cpp @@ -11,7 +11,7 @@ namespace ops { ScatterAdd::ScatterAdd(const Value& input, const Value& index, const Value& src, int64_t dim) - : Node(ir::OpKind(at::aten::scatter_add), {input, index, src}, + : Node(torch::lazy::OpKind(at::aten::scatter_add), {input, index, src}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/shrink_backward.cpp b/torch_xla/csrc/ops/shrink_backward.cpp index 75d10f829dfa..2e220aaf16b2 100644 --- a/torch_xla/csrc/ops/shrink_backward.cpp +++ b/torch_xla/csrc/ops/shrink_backward.cpp @@ -9,8 +9,9 @@ namespace torch_xla { namespace ir { namespace ops { -ShrinkBackward::ShrinkBackward(OpKind kind, const Value& grad_output, - const Value& input, const at::Scalar& lambda) +ShrinkBackward::ShrinkBackward(torch::lazy::OpKind kind, + const Value& grad_output, const Value& input, + const at::Scalar& lambda) : Node(kind, {grad_output, input}, input.shape(), /*num_outputs=*/1, ScalarHash(lambda)), lambda_(std::move(lambda)) {} @@ -22,8 +23,8 @@ std::string ShrinkBackward::ToString() const { } NodePtr ShrinkBackward::Clone(OpList operands) const { - return MakeNode(op(), operands.at(0), operands.at(1), - lambda_); + return ir::MakeNode(op(), operands.at(0), operands.at(1), + lambda_); } XlaOpVector ShrinkBackward::Lower(LoweringContext* loctx) const { diff --git a/torch_xla/csrc/ops/shrink_backward.h b/torch_xla/csrc/ops/shrink_backward.h index d9f69de633ad..79f56de4867c 100644 --- a/torch_xla/csrc/ops/shrink_backward.h +++ b/torch_xla/csrc/ops/shrink_backward.h @@ -10,8 +10,8 @@ namespace ops { class ShrinkBackward : public Node { public: - ShrinkBackward(OpKind kind, const Value& grad_output, const Value& input, - const at::Scalar& lambda); + ShrinkBackward(torch::lazy::OpKind kind, const Value& grad_output, + const Value& input, const at::Scalar& lambda); std::string ToString() const override; diff --git a/torch_xla/csrc/ops/softmax.cpp b/torch_xla/csrc/ops/softmax.cpp index 38786fc7f68e..d339ec317b53 100644 --- a/torch_xla/csrc/ops/softmax.cpp +++ b/torch_xla/csrc/ops/softmax.cpp @@ -31,7 +31,7 @@ xla::Shape NodeOutputShape(const Value& input, Softmax::Softmax(const Value& input, int64_t dim, c10::optional dtype) - : Node(ir::OpKind(at::aten::softmax), {input}, + : Node(torch::lazy::OpKind(at::aten::softmax), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, torch::lazy::MHash(dim, torch::lazy::OptionalOr(dtype, -1))), diff --git a/torch_xla/csrc/ops/softmax_backward.cpp b/torch_xla/csrc/ops/softmax_backward.cpp index 471367f63d60..7fe39c4eecea 100644 --- a/torch_xla/csrc/ops/softmax_backward.cpp +++ b/torch_xla/csrc/ops/softmax_backward.cpp @@ -11,8 +11,8 @@ namespace ops { SoftmaxBackward::SoftmaxBackward(const Value& grad_output, const Value& output, int64_t dim) - : Node(ir::OpKind(at::aten::_softmax_backward_data), {grad_output, output}, - grad_output.shape(), + : Node(torch::lazy::OpKind(at::aten::_softmax_backward_data), + {grad_output, output}, grad_output.shape(), /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/softshrink.cpp b/torch_xla/csrc/ops/softshrink.cpp index 95df051b2750..327471157246 100644 --- a/torch_xla/csrc/ops/softshrink.cpp +++ b/torch_xla/csrc/ops/softshrink.cpp @@ -10,7 +10,7 @@ namespace ir { namespace ops { Softshrink::Softshrink(const Value& input, const at::Scalar& lambda) - : Node(OpKind(at::aten::softshrink), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::softshrink), {input}, input.shape(), /*num_outputs=*/1, ScalarHash(lambda)), lambda_(std::move(lambda)) {} diff --git a/torch_xla/csrc/ops/split.cpp b/torch_xla/csrc/ops/split.cpp index ee96de21bbee..abac5ccefe1d 100644 --- a/torch_xla/csrc/ops/split.cpp +++ b/torch_xla/csrc/ops/split.cpp @@ -25,7 +25,7 @@ xla::Shape NodeOutputShape(const Value& input, } // namespace Split::Split(const Value& input, std::vector split_sizes, int64_t dim) - : Node(ir::OpKind(at::aten::split), {input}, + : Node(torch::lazy::OpKind(at::aten::split), {input}, [&]() { return NodeOutputShape(input, split_sizes, dim); }, ComputeSplitCount(input.shape().dimensions(dim), split_sizes), torch::lazy::MHash(split_sizes, dim)), diff --git a/torch_xla/csrc/ops/squeeze.cpp b/torch_xla/csrc/ops/squeeze.cpp index c25b5be31de8..0ea1c241e650 100644 --- a/torch_xla/csrc/ops/squeeze.cpp +++ b/torch_xla/csrc/ops/squeeze.cpp @@ -30,7 +30,7 @@ xla::Shape NodeOutputShape(const Value& input, int dim) { } // namespace Squeeze::Squeeze(const Value& input, int dim) - : Node(ir::OpKind(at::aten::squeeze), {input}, + : Node(torch::lazy::OpKind(at::aten::squeeze), {input}, [&]() { return NodeOutputShape(input, dim); }, /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/stack.cpp b/torch_xla/csrc/ops/stack.cpp index a6f98768df51..64b7fd8b1b08 100644 --- a/torch_xla/csrc/ops/stack.cpp +++ b/torch_xla/csrc/ops/stack.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(absl::Span values, int64_t dim) { } // namespace Stack::Stack(absl::Span values, int64_t dim) - : Node(ir::OpKind(at::aten::stack), values, + : Node(torch::lazy::OpKind(at::aten::stack), values, [&]() { return NodeOutputShape(values, dim); }, /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/std.cpp b/torch_xla/csrc/ops/std.cpp index 87cd52b40023..9d773bbc5b86 100644 --- a/torch_xla/csrc/ops/std.cpp +++ b/torch_xla/csrc/ops/std.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, Std::Std(const Value& input, std::vector dimensions, bool keep_reduced_dimensions, int64_t correction) - : Node(ir::OpKind(at::aten::std), {input}, + : Node(torch::lazy::OpKind(at::aten::std), {input}, [&]() { return NodeOutputShape(input, dimensions, keep_reduced_dimensions, correction); diff --git a/torch_xla/csrc/ops/std_mean.cpp b/torch_xla/csrc/ops/std_mean.cpp index b8de828f0221..bc3a99fdd5c2 100644 --- a/torch_xla/csrc/ops/std_mean.cpp +++ b/torch_xla/csrc/ops/std_mean.cpp @@ -27,7 +27,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, StdMean::StdMean(const Value& input, std::vector dimensions, int64_t correction, bool keep_reduced_dimensions) - : Node(ir::OpKind(at::aten::std_mean), {input}, + : Node(torch::lazy::OpKind(at::aten::std_mean), {input}, [&]() { return NodeOutputShape(input, dimensions, keep_reduced_dimensions, correction); diff --git a/torch_xla/csrc/ops/sum.cpp b/torch_xla/csrc/ops/sum.cpp index b7c3fb8245d5..5c7531f760b4 100644 --- a/torch_xla/csrc/ops/sum.cpp +++ b/torch_xla/csrc/ops/sum.cpp @@ -37,7 +37,7 @@ xla::Shape NodeOutputShape(const Value& input, Sum::Sum(const Value& input, std::vector dimensions, bool keep_reduced_dimensions, c10::optional dtype) - : Node(ir::OpKind(at::aten::sum), {input}, + : Node(torch::lazy::OpKind(at::aten::sum), {input}, [&]() { return NodeOutputShape(input, dimensions, keep_reduced_dimensions, dtype); diff --git a/torch_xla/csrc/ops/svd.cpp b/torch_xla/csrc/ops/svd.cpp index f4ef24fa75b4..2b7d38be247b 100644 --- a/torch_xla/csrc/ops/svd.cpp +++ b/torch_xla/csrc/ops/svd.cpp @@ -68,7 +68,7 @@ xla::Shape NodeOutputShape(const Value& input, bool some, bool compute_uv) { } // namespace SVD::SVD(const Value& input, bool some, bool compute_uv) - : Node(ir::OpKind(at::aten::svd), {input}, + : Node(torch::lazy::OpKind(at::aten::svd), {input}, [&]() { return NodeOutputShape(input, some, compute_uv); }, /*num_outputs=*/3, torch::lazy::MHash(some, compute_uv)), some_(some), diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp index 055a561a3c05..3a2eb46b45ca 100644 --- a/torch_xla/csrc/ops/symeig.cpp +++ b/torch_xla/csrc/ops/symeig.cpp @@ -46,7 +46,7 @@ xla::Shape NodeOutputShape(const Value& input, bool eigenvectors, bool lower) { } // namespace SymEig::SymEig(const Value& input, bool eigenvectors, bool lower) - : Node(ir::OpKind(at::aten::symeig), {input}, + : Node(torch::lazy::OpKind(at::aten::symeig), {input}, [&]() { return NodeOutputShape(input, eigenvectors, lower); }, /*num_outputs=*/2, torch::lazy::MHash(eigenvectors, lower)), eigenvectors_(eigenvectors), diff --git a/torch_xla/csrc/ops/threshold.cpp b/torch_xla/csrc/ops/threshold.cpp index cec2e47fbb63..b0e305e02d40 100644 --- a/torch_xla/csrc/ops/threshold.cpp +++ b/torch_xla/csrc/ops/threshold.cpp @@ -8,7 +8,7 @@ namespace ir { namespace ops { Threshold::Threshold(const Value& input, float threshold, float value) - : Node(ir::OpKind(at::aten::threshold), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::threshold), {input}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(threshold, value)), threshold_(threshold), value_(value) {} diff --git a/torch_xla/csrc/ops/threshold_backward.cpp b/torch_xla/csrc/ops/threshold_backward.cpp index 928903c5e4a8..e500f4b6170b 100644 --- a/torch_xla/csrc/ops/threshold_backward.cpp +++ b/torch_xla/csrc/ops/threshold_backward.cpp @@ -9,8 +9,9 @@ namespace ops { ThresholdBackward::ThresholdBackward(const Value& grad_output, const Value& input, float threshold) - : Node(ir::OpKind(at::aten::threshold_backward), {grad_output, input}, - input.shape(), /*num_outputs=*/1, torch::lazy::MHash(threshold)), + : Node(torch::lazy::OpKind(at::aten::threshold_backward), + {grad_output, input}, input.shape(), /*num_outputs=*/1, + torch::lazy::MHash(threshold)), threshold_(threshold) {} NodePtr ThresholdBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp index 5ea7907859f0..492827987f99 100644 --- a/torch_xla/csrc/ops/topk.cpp +++ b/torch_xla/csrc/ops/topk.cpp @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t k, int64_t dim, TopK::TopK(const Value& input, int64_t k, int64_t dim, bool largest, bool sorted, bool stable) - : Node(ir::OpKind(at::aten::topk), {input}, + : Node(torch::lazy::OpKind(at::aten::topk), {input}, [&]() { return NodeOutputShape(input, k, dim, largest, sorted, stable); }, diff --git a/torch_xla/csrc/ops/triangular_solve.cpp b/torch_xla/csrc/ops/triangular_solve.cpp index 196a4bdae61a..ca107c147a7c 100644 --- a/torch_xla/csrc/ops/triangular_solve.cpp +++ b/torch_xla/csrc/ops/triangular_solve.cpp @@ -76,7 +76,7 @@ xla::Shape NodeOutputShape(const Value& rhs, const Value& lhs) { TriangularSolve::TriangularSolve(const Value& rhs, const Value& lhs, bool left_side, bool lower, bool transpose, bool unit_diagonal) - : Node(ir::OpKind(at::aten::triangular_solve), {rhs, lhs}, + : Node(torch::lazy::OpKind(at::aten::triangular_solve), {rhs, lhs}, [&]() { return NodeOutputShape(rhs, lhs); }, /*num_outputs=*/2, torch::lazy::MHash(left_side, lower, transpose, unit_diagonal)), diff --git a/torch_xla/csrc/ops/tril.cpp b/torch_xla/csrc/ops/tril.cpp index 62fafe98d20b..2ad78438ba09 100644 --- a/torch_xla/csrc/ops/tril.cpp +++ b/torch_xla/csrc/ops/tril.cpp @@ -8,7 +8,7 @@ namespace ir { namespace ops { Tril::Tril(const Value& input, int64_t diagonal) - : Node(ir::OpKind(at::aten::tril), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::tril), {input}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(diagonal)), diagonal_(diagonal) {} diff --git a/torch_xla/csrc/ops/triu.cpp b/torch_xla/csrc/ops/triu.cpp index e82b7ce7153f..9cba3467939a 100644 --- a/torch_xla/csrc/ops/triu.cpp +++ b/torch_xla/csrc/ops/triu.cpp @@ -8,7 +8,7 @@ namespace ir { namespace ops { Triu::Triu(const Value& input, int64_t diagonal) - : Node(ir::OpKind(at::aten::triu), {input}, input.shape(), + : Node(torch::lazy::OpKind(at::aten::triu), {input}, input.shape(), /*num_outputs=*/1, torch::lazy::MHash(diagonal)), diagonal_(diagonal) {} diff --git a/torch_xla/csrc/ops/uniform.cpp b/torch_xla/csrc/ops/uniform.cpp index 25aa31d7f6ec..5df138c18d80 100644 --- a/torch_xla/csrc/ops/uniform.cpp +++ b/torch_xla/csrc/ops/uniform.cpp @@ -12,7 +12,7 @@ namespace ops { Uniform::Uniform(const Value& from, const Value& to, const Value& seed, const xla::Shape& rng_shape) - : Node(ir::OpKind(at::aten::uniform), {from, to, seed}, rng_shape, + : Node(torch::lazy::OpKind(at::aten::uniform), {from, to, seed}, rng_shape, /*num_outputs=*/1, torch::lazy::Hash(rng_shape)) {} NodePtr Uniform::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/unsqueeze.cpp b/torch_xla/csrc/ops/unsqueeze.cpp index 7d821f7f7d0e..31b0d16c9c2e 100644 --- a/torch_xla/csrc/ops/unsqueeze.cpp +++ b/torch_xla/csrc/ops/unsqueeze.cpp @@ -17,7 +17,7 @@ xla::Shape NodeOutputShape(const Value& input, int dim) { } // namespace Unsqueeze::Unsqueeze(const Value& input, int dim) - : Node(ir::OpKind(at::aten::unsqueeze), {input}, + : Node(torch::lazy::OpKind(at::aten::unsqueeze), {input}, [&]() { return NodeOutputShape(input, dim); }, /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/upsample_bilinear2d.cpp b/torch_xla/csrc/ops/upsample_bilinear2d.cpp index e95966c24821..bf05de891ce9 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d.cpp @@ -12,7 +12,7 @@ namespace ops { UpsampleBilinear::UpsampleBilinear(const Value& input, std::vector output_size, bool align_corners) - : Node(ir::OpKind(at::aten::upsample_bilinear2d), {input}, + : Node(torch::lazy::OpKind(at::aten::upsample_bilinear2d), {input}, [&]() { return resize::GetForwardOutputShape2d(input.shape(), output_size); }, diff --git a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp index 2ca7a8550acc..4c1ce7b7feb5 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp @@ -12,7 +12,7 @@ namespace ops { UpsampleBilinearBackward::UpsampleBilinearBackward( const Value& input, std::vector output_size, std::vector input_size, bool align_corners) - : Node(ir::OpKind(at::aten::upsample_bilinear2d_backward), {input}, + : Node(torch::lazy::OpKind(at::aten::upsample_bilinear2d_backward), {input}, [&]() { return resize::GetBackwardOutputShape2d(input.shape(), input_size); }, diff --git a/torch_xla/csrc/ops/upsample_nearest2d.cpp b/torch_xla/csrc/ops/upsample_nearest2d.cpp index 4c26a42a746d..92dfd8363f96 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d.cpp @@ -11,7 +11,7 @@ namespace ops { UpsampleNearest::UpsampleNearest(const Value& input, std::vector output_size) - : Node(ir::OpKind(at::aten::upsample_nearest2d), {input}, + : Node(torch::lazy::OpKind(at::aten::upsample_nearest2d), {input}, [&]() { return resize::GetForwardOutputShape2d(input.shape(), output_size); }, diff --git a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp index dcebda0e6907..edbff40e0492 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp @@ -12,7 +12,7 @@ namespace ops { UpsampleNearestBackward::UpsampleNearestBackward( const Value& input, std::vector output_size, std::vector input_size) - : Node(ir::OpKind(at::aten::upsample_nearest2d_backward), {input}, + : Node(torch::lazy::OpKind(at::aten::upsample_nearest2d_backward), {input}, [&]() { return resize::GetBackwardOutputShape2d(input.shape(), input_size); }, diff --git a/torch_xla/csrc/ops/user_computation.cpp b/torch_xla/csrc/ops/user_computation.cpp index 5ce0c0a0ec21..50491012b55c 100644 --- a/torch_xla/csrc/ops/user_computation.cpp +++ b/torch_xla/csrc/ops/user_computation.cpp @@ -13,7 +13,7 @@ size_t GetNumOutputs(const xla::Shape& shape) { } // namespace -UserComputation::UserComputation(OpKind op, OpList operands, +UserComputation::UserComputation(torch::lazy::OpKind op, OpList operands, ComputationPtr computation) : Node(std::move(op), operands, computation->program_shape().result(), GetNumOutputs(computation->program_shape().result()), @@ -21,7 +21,7 @@ UserComputation::UserComputation(OpKind op, OpList operands, computation_(std::move(computation)) {} NodePtr UserComputation::Clone(OpList operands) const { - return MakeNode(op(), operands, computation_); + return ir::MakeNode(op(), operands, computation_); } XlaOpVector UserComputation::Lower(LoweringContext* loctx) const { diff --git a/torch_xla/csrc/ops/user_computation.h b/torch_xla/csrc/ops/user_computation.h index 84fd55d9c80d..71b351061c06 100644 --- a/torch_xla/csrc/ops/user_computation.h +++ b/torch_xla/csrc/ops/user_computation.h @@ -9,7 +9,8 @@ namespace ops { class UserComputation : public Node { public: - UserComputation(OpKind op, OpList operands, ComputationPtr computation); + UserComputation(torch::lazy::OpKind op, OpList operands, + ComputationPtr computation); NodePtr Clone(OpList operands) const override; diff --git a/torch_xla/csrc/ops/var.cpp b/torch_xla/csrc/ops/var.cpp index 2e73068fca38..8b6de38e5d2f 100644 --- a/torch_xla/csrc/ops/var.cpp +++ b/torch_xla/csrc/ops/var.cpp @@ -27,7 +27,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, Var::Var(const Value& input, std::vector dimensions, int64_t correction, bool keep_reduced_dimensions) - : Node(ir::OpKind(at::aten::var), {input}, + : Node(torch::lazy::OpKind(at::aten::var), {input}, NodeOutputShape(input, dimensions, correction, keep_reduced_dimensions), /*num_outputs=*/1, diff --git a/torch_xla/csrc/ops/var_mean.cpp b/torch_xla/csrc/ops/var_mean.cpp index 01622dcff0fb..8bd718c5caa6 100644 --- a/torch_xla/csrc/ops/var_mean.cpp +++ b/torch_xla/csrc/ops/var_mean.cpp @@ -30,7 +30,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector& dimensions, VarMean::VarMean(const Value& input, std::vector dimensions, int64_t correction, bool keep_reduced_dimensions) - : Node(ir::OpKind(at::aten::var_mean), {input}, + : Node(torch::lazy::OpKind(at::aten::var_mean), {input}, [&]() { return NodeOutputShape(input, dimensions, correction, keep_reduced_dimensions); diff --git a/torch_xla/csrc/ops/view.cpp b/torch_xla/csrc/ops/view.cpp index 77028c0d11f8..e37c79a49966 100644 --- a/torch_xla/csrc/ops/view.cpp +++ b/torch_xla/csrc/ops/view.cpp @@ -27,7 +27,7 @@ xla::Shape NodeOutputShape(const Value& input, } // namespace View::View(const Value& input, std::vector output_size) - : Node(ir::OpKind(at::aten::view), {input}, + : Node(torch::lazy::OpKind(at::aten::view), {input}, NodeOutputShape(input, output_size), /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} diff --git a/torch_xla/csrc/ops/xla_ops.h b/torch_xla/csrc/ops/xla_ops.h index 41c682224bda..8ec8c65d21d0 100644 --- a/torch_xla/csrc/ops/xla_ops.h +++ b/torch_xla/csrc/ops/xla_ops.h @@ -13,18 +13,19 @@ class OpKindWrapper { public: OpKindWrapper(const char* name) : name_(name) {} - const OpKind& operator*() const { return get(); } + const torch::lazy::OpKind& operator*() const { return get(); } - operator OpKind() const { return get(); } + operator torch::lazy::OpKind() const { return get(); } private: - const OpKind& get() const { - std::call_once(once_, [this]() { op_kind_ = OpKind::Get(name_); }); + const torch::lazy::OpKind& get() const { + std::call_once(once_, + [this]() { op_kind_ = torch::lazy::OpKind::Get(name_); }); return op_kind_; } const char* name_; - mutable OpKind op_kind_; + mutable torch::lazy::OpKind op_kind_; mutable std::once_flag once_; }; diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 1990f81cab11..3c1d0d03610a 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -721,7 +721,8 @@ class XLATensor { static XLATensor log(const XLATensor& input); - static XLATensor log_base(const XLATensor& input, ir::OpKind op, double base); + static XLATensor log_base(const XLATensor& input, torch::lazy::OpKind op, + double base); static XLATensor log_sigmoid(const XLATensor& input); static std::tuple log_sigmoid_forward( diff --git a/torch_xla/csrc/tensor_methods.cpp b/torch_xla/csrc/tensor_methods.cpp index 2a8ce2fe5e8f..d911711fe602 100644 --- a/torch_xla/csrc/tensor_methods.cpp +++ b/torch_xla/csrc/tensor_methods.cpp @@ -504,7 +504,7 @@ std::vector XLATensor::user_computation( input_values.push_back(input.GetIrValue()); } ir::NodePtr node = ir::MakeNode( - ir::OpKind::Get(opname), input_values, std::move(computation)); + torch::lazy::OpKind::Get(opname), input_values, std::move(computation)); // Cast can be one of the user computation and we don't want to inherit the // logical_element_type in this case return inputs.front().MakeOutputTensors(node, /*inherit_logical_type=*/false); @@ -1586,7 +1586,7 @@ XLATensor XLATensor::hardshrink_backward(const XLATensor& grad_out, const XLATensor& input, const at::Scalar& lambda) { return input.CreateFrom(ir::MakeNode( - ir::OpKind(at::aten::hardshrink_backward), grad_out.GetIrValue(), + torch::lazy::OpKind(at::aten::hardshrink_backward), grad_out.GetIrValue(), input.GetIrValue(), lambda)); } @@ -1655,7 +1655,7 @@ XLATensor XLATensor::log(const XLATensor& input) { c10::nullopt); } -XLATensor XLATensor::log_base(const XLATensor& input, ir::OpKind op, +XLATensor XLATensor::log_base(const XLATensor& input, torch::lazy::OpKind op, double base) { // Here we explictly pass c10::nullopt as logical_element_type because // otherwise result will inherit the input's logical_element_type. In the @@ -2568,7 +2568,7 @@ XLATensor XLATensor::softshrink_backward(const XLATensor& grad_out, const XLATensor& input, const at::Scalar& lambda) { return input.CreateFrom(ir::MakeNode( - ir::OpKind(at::aten::softshrink_backward), grad_out.GetIrValue(), + torch::lazy::OpKind(at::aten::softshrink_backward), grad_out.GetIrValue(), input.GetIrValue(), lambda)); }