4 changes: 2 additions & 2 deletions torch_xla/csrc/aten_xla_type.cpp
@@ -1821,7 +1821,7 @@ at::Tensor XLANativeFunctions::log(const at::Tensor& self) {
at::Tensor XLANativeFunctions::log10(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::log_base(
-      bridge::GetXlaTensor(self), ir::OpKind(at::aten::log10), 10.0));
+      bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log10), 10.0));
}

at::Tensor XLANativeFunctions::log1p(const at::Tensor& self) {
@@ -1833,7 +1833,7 @@ at::Tensor XLANativeFunctions::log1p(const at::Tensor& self) {
at::Tensor XLANativeFunctions::log2(const at::Tensor& self) {
XLA_FN_COUNTER("xla::");
return bridge::AtenFromXlaTensor(XLATensor::log_base(
-      bridge::GetXlaTensor(self), ir::OpKind(at::aten::log2), 2.0));
+      bridge::GetXlaTensor(self), torch::lazy::OpKind(at::aten::log2), 2.0));
}

at::Tensor XLANativeFunctions::log_sigmoid_backward(
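The call-site change above is mechanical: the locally defined `ir::OpKind` is swapped for the upstream `torch::lazy::OpKind`, and both simply wrap a `c10::Symbol`. A minimal sketch of how such an op kind is constructed and compared (illustrative only, not part of this diff; the custom op name is a hypothetical example):

```cpp
#include <ATen/core/interned_strings.h>  // at::aten::* symbols
#include "torch/csrc/lazy/core/ir.h"     // torch::lazy::OpKind

void OpKindSketch() {
  // Constructed from an ATen symbol, exactly as in the calls above.
  torch::lazy::OpKind log10_kind(at::aten::log10);
  // Or interned from a qualified string, for backend-specific ops.
  torch::lazy::OpKind custom = torch::lazy::OpKind::Get("xla::my_custom_op");
  // OpKinds are comparable and hashable, which the IR uses for node caching.
  bool same = (log10_kind == custom);  // false
  (void)same;
}
```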
19 changes: 6 additions & 13 deletions torch_xla/csrc/ir.cpp
@@ -127,16 +127,8 @@ torch::lazy::hash_t Value::hash() const {
return torch::lazy::HashCombine(node->hash(), index);
}

-OpKind OpKind::Get(const std::string& name) {
-  return OpKind(c10::Symbol::fromQualString(name));
-}
-
-torch::lazy::hash_t OpKind::hash() const {
-  return torch::lazy::StringHash(op.toQualString());
-}
-
-Node::Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs,
-           torch::lazy::hash_t hash_seed)
+Node::Node(torch::lazy::OpKind op, OpList operands, xla::Shape shape,
+           size_t num_outputs, torch::lazy::hash_t hash_seed)
: op_(std::move(op)),
num_outputs_(num_outputs),
shape_(std::move(shape)),
@@ -150,7 +142,7 @@ Node::Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs,
}
}

-Node::Node(OpKind op, OpList operands,
+Node::Node(torch::lazy::OpKind op, OpList operands,
const std::function<xla::Shape()>& shape_fn, size_t num_outputs,
torch::lazy::hash_t hash_seed)
: Node(std::move(op), operands, xla::Shape(), num_outputs, hash_seed) {
@@ -159,7 +151,7 @@ Node::Node(OpKind op, OpList operands,
shape_ = GetOpShape(shape_fn);
}

-Node::Node(OpKind op, xla::Shape shape, size_t num_outputs,
+Node::Node(torch::lazy::OpKind op, xla::Shape shape, size_t num_outputs,
torch::lazy::hash_t hash_seed)
: op_(std::move(op)),
num_outputs_(num_outputs),
@@ -247,7 +239,7 @@ XlaOpVector Node::Lower(LoweringContext* loctx) const {
XLA_ERROR() << "Lowering not implemented for node: " << *this;
}

-torch::lazy::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape,
+torch::lazy::hash_t Node::GetOpHash(torch::lazy::OpKind op,
+                                    const xla::Shape& shape,
torch::lazy::hash_t hash_seed) {
torch::lazy::hash_t h =
torch::lazy::HashCombine(op.hash(), torch::lazy::Hash(shape.ToString()));
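The deleted `OpKind::Get` and `OpKind::hash` are not lost functionality: the upstream `torch::lazy::OpKind` ships the same members, which is what lets this PR delete the XLA copies. For reference, a sketch of the upstream equivalents, based on the removed definitions above (the exact upstream code in `torch/csrc/lazy/core` at the time of this PR may differ slightly):

```cpp
// namespace torch::lazy -- upstream equivalents of the removed methods.
OpKind OpKind::Get(const std::string& name) {
  // Interns the qualified string (e.g. "aten::add" or "xla::some_op").
  return OpKind(c10::Symbol::fromQualString(name));
}

hash_t OpKind::hash() const {
  // Hashes the fully qualified op name, as the removed XLA copy did.
  return StringHash(op.toQualString());
}
```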
47 changes: 11 additions & 36 deletions torch_xla/csrc/ir.h
@@ -18,6 +18,7 @@
#include "tensorflow/compiler/xla/xla_client/types.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "torch/csrc/lazy/core/hash.h"
#include "torch/csrc/lazy/core/ir.h"
#include "torch_xla/csrc/python_util.h"

namespace torch_xla {
@@ -134,34 +135,6 @@ struct Value {
size_t index = 0;
};

-// The Kind of operation a Node can be associated to.
-struct OpKind {
-  OpKind() = default;
-  explicit OpKind(c10::Symbol op) : op(std::move(op)) {}
-
-  bool operator==(const OpKind& rhs) const { return op == rhs.op; }
-  bool operator!=(const OpKind& rhs) const { return !operator==(rhs); }
-  bool operator<(const OpKind& rhs) const {
-    return c10::unique_t(op) < c10::unique_t(rhs.op);
-  }
-
-  torch::lazy::hash_t hash() const;
-
-  std::string ToString() const { return op.toQualString(); }
-
-  // Retrieves an existing operation object, or creates a new one. Operations
-  // that are specific to the XLA side, should live within the 'xla::'
-  // namespace.
-  static OpKind Get(const std::string& name);
-
-  c10::Symbol op;
-};
-
-inline std::ostream& operator<<(std::ostream& stream, const OpKind& op) {
-  stream << op.ToString();
-  return stream;
-}
-
using OpList = absl::Span<const Value>;

// A node in the graph. Nodes for operations which require extra data to be
@@ -175,22 +148,23 @@ class Node {
// Creates a new node with the given op name. The op is a unique identifier
// for the operation. The num_outputs tells how many outputs a given operation
// generates.
-  Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs = 1,
+  Node(torch::lazy::OpKind op, OpList operands, xla::Shape shape,
+       size_t num_outputs = 1,
torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9);

// Same as the constructor above, but the shape is generated by a function,
// only if needed (shape cache miss).
-  Node(OpKind op, OpList operands, const std::function<xla::Shape()>& shape_fn,
-       size_t num_outputs = 1,
+  Node(torch::lazy::OpKind op, OpList operands,
+       const std::function<xla::Shape()>& shape_fn, size_t num_outputs = 1,
torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9);

// Constructor used to create leaf nodes.
-  Node(OpKind op, xla::Shape shape, size_t num_outputs,
+  Node(torch::lazy::OpKind op, xla::Shape shape, size_t num_outputs,
torch::lazy::hash_t hash_seed);

virtual ~Node();

-  const OpKind& op() const { return op_; }
+  const torch::lazy::OpKind& op() const { return op_; }

size_t num_outputs() const { return num_outputs_; }

@@ -247,13 +221,14 @@ class Node {

xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;

-  static torch::lazy::hash_t GetOpHash(OpKind op, const xla::Shape& shape,
+  static torch::lazy::hash_t GetOpHash(torch::lazy::OpKind op,
+                                       const xla::Shape& shape,
torch::lazy::hash_t hash_seed);

static std::vector<SourceLocation> GetFrameInfo();

// The ID of the operation captured by this node.
-  OpKind op_;
+  torch::lazy::OpKind op_;
size_t num_outputs_ = 1;
xla::Shape shape_;
// A node holds a real reference to its operands.
@@ -295,7 +270,7 @@ NodePtr MakeNode(Args&&... args) {
}

template <typename T>
-T* NodeCast(const Node* node, OpKind op) {
+T* NodeCast(const Node* node, torch::lazy::OpKind op) {
if (op != node->op()) {
return nullptr;
}
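With the local struct removed, every public surface of `Node` (constructors, `op()`, `GetOpHash`, `NodeCast`) now takes `torch::lazy::OpKind`. A hypothetical usage sketch of the updated `NodeCast` signature (the surrounding function is an assumption for illustration; `ops::Cat` is the concrete node type changed later in this diff):

```cpp
// Down-cast a generic IR node to a concrete op type, guarded by op kind.
void HandleIfCat(const torch_xla::ir::Node* node) {
  using torch_xla::ir::ops::Cat;
  if (const Cat* cat = torch_xla::ir::NodeCast<Cat>(
          node, torch::lazy::OpKind(at::aten::cat))) {
    // Safe: NodeCast returned non-null only because node->op() matched.
    int64_t dim = cat->dim();
    (void)dim;
  }
}
```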
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/adaptive_avg_pool2d.cpp
@@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input,

AdaptiveAvgPool2d::AdaptiveAvgPool2d(const Value& input,
std::vector<int64_t> output_size)
-    : Node(ir::OpKind(at::aten::adaptive_avg_pool2d), {input},
+    : Node(torch::lazy::OpKind(at::aten::adaptive_avg_pool2d), {input},
[&]() { return NodeOutputShape(input, output_size); },
/*num_outputs=*/1, torch::lazy::MHash(output_size)),
output_size_(std::move(output_size)) {}
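Two shape conventions are visible across these op constructors: cheap shapes are passed to `Node` as a ready `xla::Shape`, while potentially expensive ones, like the pooling shapes here, are passed as a `[&]() { return NodeOutputShape(...); }` lambda that only runs on a shape-cache miss. A simplified sketch of the deferred-shape mechanism behind `Node::GetOpShape` (not the exact implementation; the cache API is an assumption):

```cpp
// Simplified sketch: compute the shape lazily and memoize it by node hash.
xla::Shape Node::GetOpShape(
    const std::function<xla::Shape()>& shape_fn) const {
  ShapeCache* shape_cache = GetShapeCache();  // assumed process-wide cache
  auto shape = shape_cache->Get(hash());
  if (shape == nullptr) {
    // Cache miss: run the (possibly expensive) shape function once.
    shape = shape_cache->Add(hash(), std::make_shared<xla::Shape>(shape_fn()));
  }
  return *shape;
}
```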
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/adaptive_avg_pool3d.cpp
@@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input,

AdaptiveAvgPool3d::AdaptiveAvgPool3d(const Value& input,
std::vector<int64_t> output_size)
-    : Node(ir::OpKind(at::aten::adaptive_avg_pool3d), {input},
+    : Node(torch::lazy::OpKind(at::aten::adaptive_avg_pool3d), {input},
[&]() { return NodeOutputShape(input, output_size); },
/*num_outputs=*/1, torch::lazy::MHash(output_size)),
output_size_(std::move(output_size)) {}
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/adaptive_max_pool2d.cpp
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ xla::Shape NodeOutputShape(const Value& input,

AdaptiveMaxPool2d::AdaptiveMaxPool2d(const Value& input,
std::vector<int64_t> output_size)
-    : Node(ir::OpKind(at::aten::adaptive_max_pool2d), {input},
+    : Node(torch::lazy::OpKind(at::aten::adaptive_max_pool2d), {input},
[&]() { return NodeOutputShape(input, output_size); },
/*num_outputs=*/2, torch::lazy::MHash(output_size)),
output_size_(std::move(output_size)) {}
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/all.cpp
@@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector<int64_t>& dimensions,

All::All(const Value& input, std::vector<int64_t> dimensions,
bool keep_reduced_dimensions)
-    : Node(ir::OpKind(at::aten::all), {input},
+    : Node(torch::lazy::OpKind(at::aten::all), {input},
NodeOutputShape(input, dimensions, keep_reduced_dimensions),
/*num_outputs=*/1,
torch::lazy::MHash(dimensions, keep_reduced_dimensions)),
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/amax.cpp
@@ -21,7 +21,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector<int64_t>& dimensions,
} // namespace

Amax::Amax(const Value& input, std::vector<int64_t> dimensions, bool keepdim)
-    : Node(ir::OpKind(at::aten::amax), {input},
+    : Node(torch::lazy::OpKind(at::aten::amax), {input},
[&]() { return NodeOutputShape(input, dimensions, keepdim); },
/*num_outputs=*/1, torch::lazy::MHash(dimensions, keepdim)),
dimensions_(std::move(dimensions)),
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/amin.cpp
@@ -21,7 +21,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector<int64_t>& dimensions,
} // namespace

Amin::Amin(const Value& input, std::vector<int64_t> dimensions, bool keepdim)
-    : Node(ir::OpKind(at::aten::amin), {input},
+    : Node(torch::lazy::OpKind(at::aten::amin), {input},
[&]() { return NodeOutputShape(input, dimensions, keepdim); },
/*num_outputs=*/1, torch::lazy::MHash(dimensions, keepdim)),
dimensions_(std::move(dimensions)),
torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp
@@ -35,7 +35,8 @@ std::vector<Value> GetOperandList(absl::Span<const Value> operands,

AmpForachNonFiniteCheckAndUnscale::AmpForachNonFiniteCheckAndUnscale(
const OpList& inputs, const Value& found_inf, const Value& inv_scale)
-    : Node(ir::OpKind(at::aten::_amp_foreach_non_finite_check_and_unscale_),
+    : Node(torch::lazy::OpKind(
+               at::aten::_amp_foreach_non_finite_check_and_unscale_),
GetOperandList(inputs, found_inf, inv_scale),
NodeOutputShape(inputs, found_inf),
/*num_outputs=*/inputs.size() + 1) {}
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/amp_update_scale.cpp
@@ -24,7 +24,7 @@ AmpUpdateScale::AmpUpdateScale(const Value& current_scale,
const Value& found_inf,
double scale_growth_factor,
double scale_backoff_factor, int growth_interval)
-    : Node(ir::OpKind(at::aten::_amp_update_scale_),
+    : Node(torch::lazy::OpKind(at::aten::_amp_update_scale_),
{current_scale, growth_tracker, found_inf},
NodeOutputShape(growth_tracker, current_scale),
/*num_outputs=*/2),
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/any.cpp
@@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, std::vector<int64_t>& dimensions,

Any::Any(const Value& input, std::vector<int64_t> dimensions,
bool keep_reduced_dimensions)
-    : Node(ir::OpKind(at::aten::any), {input},
+    : Node(torch::lazy::OpKind(at::aten::any), {input},
NodeOutputShape(input, dimensions, keep_reduced_dimensions),
/*num_outputs=*/1,
torch::lazy::MHash(dimensions, keep_reduced_dimensions)),
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/arg_max.cpp
@@ -20,7 +20,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t dim, bool keepdim) {
} // namespace

ArgMax::ArgMax(const Value& input, int64_t dim, bool keepdim)
-    : Node(ir::OpKind(at::aten::argmax), {input},
+    : Node(torch::lazy::OpKind(at::aten::argmax), {input},
[&]() { return NodeOutputShape(input, dim, keepdim); },
/*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)),
dim_(dim),
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/arg_min.cpp
@@ -20,7 +20,7 @@ xla::Shape NodeOutputShape(const Value& input, int64_t dim, bool keepdim) {
} // namespace

ArgMin::ArgMin(const Value& input, int64_t dim, bool keepdim)
-    : Node(ir::OpKind(at::aten::argmin), {input},
+    : Node(torch::lazy::OpKind(at::aten::argmin), {input},
[&]() { return NodeOutputShape(input, dim, keepdim); },
/*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)),
dim_(dim),
8 changes: 4 additions & 4 deletions torch_xla/csrc/ops/arithmetic_ir_ops.cpp
@@ -16,7 +16,7 @@ NodePtr operator+(const Value& node1, const Value& node2) {
return node.ReturnOp(XlaHelpers::PromotedAdd(op0, op1), loctx);
};
return ops::GenericOp(
-      OpKind(at::aten::add), {node1, node2},
+      torch::lazy::OpKind(at::aten::add), {node1, node2},
XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()),
std::move(lower_fn));
}
@@ -28,7 +28,7 @@ NodePtr operator-(const Value& node1, const Value& node2) {
return node.ReturnOp(XlaHelpers::PromotedSub(op0, op1), loctx);
};
return ops::GenericOp(
-      OpKind(at::aten::sub), {node1, node2},
+      torch::lazy::OpKind(at::aten::sub), {node1, node2},
XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()),
std::move(lower_fn));
}
@@ -40,7 +40,7 @@ NodePtr operator*(const Value& node1, const Value& node2) {
return node.ReturnOp(XlaHelpers::PromotedMul(op0, op1), loctx);
};
return ops::GenericOp(
-      OpKind(at::aten::mul), {node1, node2},
+      torch::lazy::OpKind(at::aten::mul), {node1, node2},
XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()),
std::move(lower_fn));
}
@@ -52,7 +52,7 @@ NodePtr operator/(const Value& node1, const Value& node2) {
return node.ReturnOp(XlaHelpers::PromotedDiv(op0, op1), loctx);
};
return ops::GenericOp(
-      OpKind(at::aten::div), {node1, node2},
+      torch::lazy::OpKind(at::aten::div), {node1, node2},
XlaHelpers::GetPromotedBinaryOpShape(node1.shape(), node2.shape()),
std::move(lower_fn));
}
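These overloads let IR values compose like ordinary scalars: each operator allocates a `GenericOp` node tagged with the matching ATen op kind and carrying a lowering lambda that emits the type-promoted XLA op later. A hypothetical helper showing the resulting style (illustration only; the function name is an assumption):

```cpp
// Builds the IR for a*x + b. Nothing touches XLA yet: each operator just
// adds a GenericOp node (mul, then add) to the lazy graph, and lowering
// happens when the graph is compiled.
torch_xla::ir::NodePtr AxPlusB(const torch_xla::ir::Value& a,
                               const torch_xla::ir::Value& x,
                               const torch_xla::ir::Value& b) {
  return torch_xla::ir::Value(a * x) + b;
}
```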
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/as_strided.cpp
@@ -45,7 +45,7 @@ xla::XlaOp LowerAsStrided(xla::XlaOp input, absl::Span<const int64_t> size,

AsStrided::AsStrided(const Value& input, std::vector<int64_t> size,
std::vector<int64_t> stride, int64_t storage_offset)
-    : Node(ir::OpKind(at::aten::as_strided), {input},
+    : Node(torch::lazy::OpKind(at::aten::as_strided), {input},
[&]() {
return xla::ShapeUtil::MakeShape(input.shape().element_type(),
size);
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/avg_pool_nd.cpp
@@ -47,7 +47,7 @@ AvgPoolNd::AvgPoolNd(const Value& input, int64_t spatial_dim_count,
std::vector<int64_t> kernel_size,
std::vector<int64_t> stride, std::vector<int64_t> padding,
bool ceil_mode, bool count_include_pad)
-    : Node(ir::OpKind(AvgPoolNdSymbol(spatial_dim_count)), {input},
+    : Node(torch::lazy::OpKind(AvgPoolNdSymbol(spatial_dim_count)), {input},
[&]() {
return NodeOutputShape(input, spatial_dim_count, kernel_size,
stride, padding, ceil_mode,
3 changes: 2 additions & 1 deletion torch_xla/csrc/ops/avg_pool_nd_backward.cpp
@@ -47,7 +47,8 @@ AvgPoolNdBackward::AvgPoolNdBackward(
const Value& grad_output, const Value& input, int64_t spatial_dim_count,
std::vector<int64_t> kernel_size, std::vector<int64_t> stride,
std::vector<int64_t> padding, bool ceil_mode, bool count_include_pad)
-    : Node(OpKind(AvgNdBackwardSymbol(spatial_dim_count)), {grad_output, input},
+    : Node(torch::lazy::OpKind(AvgNdBackwardSymbol(spatial_dim_count)),
+           {grad_output, input},
[&]() {
return NodeOutputShape(grad_output, input, spatial_dim_count,
kernel_size, stride, padding, ceil_mode,
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/bernoulli.cpp
@@ -10,7 +10,7 @@ namespace ops {

Bernoulli::Bernoulli(const Value& probability, const Value& seed,
xla::Shape shape)
-    : Node(ir::OpKind(at::aten::bernoulli), {probability, seed},
+    : Node(torch::lazy::OpKind(at::aten::bernoulli), {probability, seed},
std::move(shape)) {}

NodePtr Bernoulli::Clone(OpList operands) const {
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/binary_cross_entropy.cpp
@@ -35,7 +35,7 @@ xla::Shape NodeOutputShape(const Value& logits, const Value& labels,
BinaryCrossEntropy::BinaryCrossEntropy(const Value& logits, const Value& labels,
const absl::optional<Value>& weight,
ReductionMode reduction)
-    : Node(ir::OpKind(at::aten::binary_cross_entropy),
+    : Node(torch::lazy::OpKind(at::aten::binary_cross_entropy),
xla::util::GetValuesVector<Value>({logits, labels}, {&weight}),
[&]() { return NodeOutputShape(logits, labels, weight, reduction); },
/*num_outputs=*/1,
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/binary_cross_entropy_backward.cpp
@@ -37,7 +37,7 @@ xla::Shape NodeOutputShape(const Value& grad_output, const Value& logits,
BinaryCrossEntropyBackward::BinaryCrossEntropyBackward(
const Value& grad_output, const Value& logits, const Value& labels,
const absl::optional<Value>& weight, ReductionMode reduction)
-    : Node(ir::OpKind(at::aten::binary_cross_entropy_backward),
+    : Node(torch::lazy::OpKind(at::aten::binary_cross_entropy_backward),
xla::util::GetValuesVector<Value>({grad_output, logits, labels},
{&weight}),
[&]() {
6 changes: 3 additions & 3 deletions torch_xla/csrc/ops/bitwise_ir_ops.cpp
@@ -23,7 +23,7 @@ Value BitwiseAnd(const Value& node1, const Value& node2) {
[](xla::XlaOp lhs, xla::XlaOp rhs) { return lhs & rhs; });
};
return GenericOp(
-      OpKind(at::aten::__and__), {node1, node2},
+      torch::lazy::OpKind(at::aten::__and__), {node1, node2},
[&]() {
return InferOutputShape({node1.shape(), node2.shape()}, shape_fn);
},
@@ -44,7 +44,7 @@ Value BitwiseOr(const Value& node1, const Value& node2) {
[](xla::XlaOp lhs, xla::XlaOp rhs) { return lhs | rhs; });
};
return GenericOp(
-      OpKind(at::aten::__or__), {node1, node2},
+      torch::lazy::OpKind(at::aten::__or__), {node1, node2},
[&]() {
return InferOutputShape({node1.shape(), node2.shape()}, shape_fn);
},
@@ -65,7 +65,7 @@ Value BitwiseXor(const Value& node1, const Value& node2) {
[](xla::XlaOp lhs, xla::XlaOp rhs) { return lhs ^ rhs; });
};
return GenericOp(
-      OpKind(at::aten::__xor__), {node1, node2},
+      torch::lazy::OpKind(at::aten::__xor__), {node1, node2},
[&]() {
return InferOutputShape({node1.shape(), node2.shape()}, shape_fn);
},
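Unlike ops with closed-form output shapes, the bitwise ops derive theirs through `InferOutputShape`, which traces the same lowering lambda on placeholder parameters and reads back the shape XLA infers. A rough sketch of the idea (simplified, not the actual implementation; the result-shape helper is an assumption):

```cpp
// Simplified sketch: build a throwaway XLA computation whose inputs mirror
// the operand shapes, run the op's lowering function on them, and return
// the shape of the result it produces.
xla::Shape InferOutputShape(
    absl::Span<const xla::Shape> input_shapes,
    const std::function<xla::XlaOp(absl::Span<const xla::XlaOp>)>& op_fn) {
  xla::XlaBuilder builder("InferOutputShape");
  std::vector<xla::XlaOp> parameters;
  for (size_t i = 0; i < input_shapes.size(); ++i) {
    parameters.push_back(xla::Parameter(&builder, /*parameter_number=*/i,
                                        input_shapes[i],
                                        absl::StrCat("p", i)));
  }
  xla::XlaOp result = op_fn(parameters);
  return XlaHelpers::ShapeOfXlaOp(result);  // assumed torch_xla helper
}
```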
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/cat.cpp
@@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(absl::Span<const ir::Value> values, int64_t dim) {
} // namespace

Cat::Cat(absl::Span<const ir::Value> values, int64_t dim)
-    : Node(ir::OpKind(at::aten::cat), values,
+    : Node(torch::lazy::OpKind(at::aten::cat), values,
[&]() { return NodeOutputShape(values, dim); },
/*num_outputs=*/1, torch::lazy::MHash(dim)),
dim_(dim) {}
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/cholesky.cpp
@@ -9,7 +9,7 @@ namespace ir {
namespace ops {

Cholesky::Cholesky(const Value& input, bool lower)
-    : Node(ir::OpKind(at::aten::cholesky), {input}, input.shape(),
+    : Node(torch::lazy::OpKind(at::aten::cholesky), {input}, input.shape(),
/*num_outputs=*/1, torch::lazy::MHash(lower)),
lower_(lower) {}
