From 8e63259f4a157425c7036a72b72f8c30936c06b3 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Fri, 1 Oct 2021 14:32:41 -0700 Subject: [PATCH 1/9] Replace Usages of xla:util::Hash* with torch::lazy::Hash* As part of migration of core lazy tensor functionality to PyTorch core, this uses the newly added torch::lazy::Hash functions from PyTorch. Note: while the Hash* functions are largely identical to the original XLA ones, the underlying uint128 class is from protobuf instead of absl, since it was a slightly smaller dependency to ingest and get building on multiple OS/platform combinations for PyTorch. --- torch_xla/csrc/computation.cpp | 2 +- torch_xla/csrc/computation.h | 5 ++- torch_xla/csrc/debug_util.cpp | 5 ++- torch_xla/csrc/device.h | 5 ++- torch_xla/csrc/ir.cpp | 37 +++++++++-------- torch_xla/csrc/ir.h | 25 +++++------ torch_xla/csrc/op_by_op_executor.cpp | 34 ++++++++------- torch_xla/csrc/op_by_op_executor.h | 4 +- torch_xla/csrc/ops/adaptive_avg_pool2d.cpp | 2 +- torch_xla/csrc/ops/adaptive_avg_pool3d.cpp | 2 +- torch_xla/csrc/ops/adaptive_max_pool2d.cpp | 2 +- torch_xla/csrc/ops/all.cpp | 2 +- torch_xla/csrc/ops/all_reduce.cpp | 2 +- torch_xla/csrc/ops/all_to_all.cpp | 2 +- torch_xla/csrc/ops/amax.cpp | 2 +- torch_xla/csrc/ops/amin.cpp | 2 +- torch_xla/csrc/ops/any.cpp | 2 +- torch_xla/csrc/ops/arg_max.cpp | 2 +- torch_xla/csrc/ops/arg_min.cpp | 2 +- torch_xla/csrc/ops/as_strided.cpp | 2 +- torch_xla/csrc/ops/as_strided_view_update.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/binary_cross_entropy.cpp | 2 +- .../ops/binary_cross_entropy_backward.cpp | 2 +- torch_xla/csrc/ops/cast.cpp | 4 +- torch_xla/csrc/ops/cat.cpp | 2 +- torch_xla/csrc/ops/cholesky.cpp | 2 +- torch_xla/csrc/ops/collective_permute.cpp | 2 +- torch_xla/csrc/ops/constant_pad_nd.cpp | 2 +- .../ops/convolution_backward_overrideable.cpp | 2 +- .../csrc/ops/convolution_overrideable.cpp | 4 +- torch_xla/csrc/ops/cumprod.cpp | 2 +- torch_xla/csrc/ops/cumsum.cpp | 2 +- torch_xla/csrc/ops/device_data.cpp | 2 +- torch_xla/csrc/ops/diagonal.cpp | 2 +- torch_xla/csrc/ops/diagonal_view_update.cpp | 2 +- torch_xla/csrc/ops/discrete_uniform.cpp | 3 +- torch_xla/csrc/ops/expand.cpp | 2 +- torch_xla/csrc/ops/flip.cpp | 2 +- torch_xla/csrc/ops/gather.cpp | 2 +- torch_xla/csrc/ops/generic.cpp | 6 +-- torch_xla/csrc/ops/generic.h | 8 ++-- torch_xla/csrc/ops/generic_slice.cpp | 3 +- torch_xla/csrc/ops/get_dimensions_size.cpp | 2 +- torch_xla/csrc/ops/hardtanh_backward.cpp | 2 +- torch_xla/csrc/ops/index_get.cpp | 2 +- torch_xla/csrc/ops/index_ops.cpp | 2 +- torch_xla/csrc/ops/index_put.cpp | 2 +- torch_xla/csrc/ops/index_select.cpp | 2 +- torch_xla/csrc/ops/kth_value.cpp | 2 +- torch_xla/csrc/ops/l1_loss.cpp | 2 +- torch_xla/csrc/ops/l1_loss_backward.cpp | 2 +- torch_xla/csrc/ops/leaky_relu.cpp | 2 +- torch_xla/csrc/ops/leaky_relu_backward.cpp | 2 +- torch_xla/csrc/ops/linear_interpolation.cpp | 2 +- torch_xla/csrc/ops/log_softmax.cpp | 2 +- torch_xla/csrc/ops/log_softmax_backward.cpp | 2 +- torch_xla/csrc/ops/logsumexp.cpp | 2 +- torch_xla/csrc/ops/max_in_dim.cpp | 2 +- torch_xla/csrc/ops/max_pool_nd.cpp | 2 +- torch_xla/csrc/ops/max_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/max_unpool_nd.cpp | 2 +- torch_xla/csrc/ops/max_unpool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/mean.cpp | 2 +- torch_xla/csrc/ops/min_in_dim.cpp | 2 +- torch_xla/csrc/ops/mse_loss.cpp | 2 +- torch_xla/csrc/ops/mse_loss_backward.cpp | 2 +- .../csrc/ops/native_batch_norm_backward.cpp | 2 +- .../csrc/ops/native_batch_norm_forward.cpp | 2 +- torch_xla/csrc/ops/nll_loss.cpp | 2 +- torch_xla/csrc/ops/nll_loss2d.cpp | 2 +- torch_xla/csrc/ops/nll_loss2d_backward.cpp | 2 +- torch_xla/csrc/ops/nll_loss_backward.cpp | 2 +- torch_xla/csrc/ops/nms.cpp | 2 +- torch_xla/csrc/ops/not_supported.cpp | 2 +- torch_xla/csrc/ops/ops.cpp | 4 +- torch_xla/csrc/ops/ops.h | 6 +-- torch_xla/csrc/ops/permute.cpp | 2 +- torch_xla/csrc/ops/prod.cpp | 2 +- torch_xla/csrc/ops/put.cpp | 2 +- torch_xla/csrc/ops/qr.cpp | 2 +- torch_xla/csrc/ops/reflection_pad2d.cpp | 2 +- .../csrc/ops/reflection_pad2d_backward.cpp | 2 +- torch_xla/csrc/ops/repeat.cpp | 2 +- torch_xla/csrc/ops/replication_pad.cpp | 2 +- .../csrc/ops/replication_pad_backward.cpp | 2 +- torch_xla/csrc/ops/resize.cpp | 2 +- torch_xla/csrc/ops/rrelu_with_noise.cpp | 2 +- .../csrc/ops/rrelu_with_noise_backward.cpp | 2 +- torch_xla/csrc/ops/scalar.cpp | 6 +-- torch_xla/csrc/ops/scalar.h | 2 +- torch_xla/csrc/ops/scatter.cpp | 2 +- torch_xla/csrc/ops/scatter_add.cpp | 2 +- torch_xla/csrc/ops/select.cpp | 2 +- torch_xla/csrc/ops/softmax.cpp | 2 +- torch_xla/csrc/ops/softmax_backward.cpp | 2 +- torch_xla/csrc/ops/split.cpp | 2 +- torch_xla/csrc/ops/squeeze.cpp | 2 +- torch_xla/csrc/ops/stack.cpp | 2 +- torch_xla/csrc/ops/std.cpp | 2 +- torch_xla/csrc/ops/std_mean.cpp | 2 +- torch_xla/csrc/ops/sum.cpp | 2 +- torch_xla/csrc/ops/svd.cpp | 2 +- torch_xla/csrc/ops/symeig.cpp | 2 +- torch_xla/csrc/ops/threshold.cpp | 2 +- torch_xla/csrc/ops/threshold_backward.cpp | 2 +- torch_xla/csrc/ops/topk.cpp | 2 +- torch_xla/csrc/ops/triangular_solve.cpp | 2 +- torch_xla/csrc/ops/tril.cpp | 2 +- torch_xla/csrc/ops/triu.cpp | 2 +- torch_xla/csrc/ops/uniform.cpp | 3 +- torch_xla/csrc/ops/unselect.cpp | 2 +- torch_xla/csrc/ops/unsqueeze.cpp | 2 +- torch_xla/csrc/ops/update_slice.cpp | 3 +- torch_xla/csrc/ops/upsample_bilinear2d.cpp | 2 +- .../csrc/ops/upsample_bilinear2d_backward.cpp | 2 +- torch_xla/csrc/ops/upsample_nearest2d.cpp | 2 +- .../csrc/ops/upsample_nearest2d_backward.cpp | 2 +- torch_xla/csrc/ops/var.cpp | 2 +- torch_xla/csrc/ops/var_mean.cpp | 2 +- torch_xla/csrc/ops/view.cpp | 2 +- torch_xla/csrc/tensor.cpp | 41 ++++++++++--------- torch_xla/csrc/tensor.h | 6 +-- torch_xla/csrc/tensor_util.cpp | 27 ++++++------ torch_xla/csrc/tensor_util.h | 3 +- torch_xla/csrc/torch_util.cpp | 11 +++++ torch_xla/csrc/torch_util.h | 15 +++++++ 128 files changed, 259 insertions(+), 217 deletions(-) diff --git a/torch_xla/csrc/computation.cpp b/torch_xla/csrc/computation.cpp index 81600c613fb..5a83e06be07 100644 --- a/torch_xla/csrc/computation.cpp +++ b/torch_xla/csrc/computation.cpp @@ -8,7 +8,7 @@ namespace torch_xla { Computation::Computation(std::string name, xla::XlaComputation computation) : name_(std::move(name)), computation_(std::move(computation)) { program_shape_ = ConsumeValue(computation_.GetProgramShape()); - hash_ = xla::util::MHash(name_, computation_.proto().SerializeAsString()); + hash_ = torch::lazy::MHash(name_, computation_.proto().SerializeAsString()); } } // namespace torch_xla diff --git a/torch_xla/csrc/computation.h b/torch_xla/csrc/computation.h index b93634c4c2d..a02a1b95fe3 100644 --- a/torch_xla/csrc/computation.h +++ b/torch_xla/csrc/computation.h @@ -6,6 +6,7 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_client/types.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { @@ -19,13 +20,13 @@ class Computation { const xla::ProgramShape& program_shape() const { return program_shape_; } - const xla::hash_t& hash() const { return hash_; } + const torch::lazy::hash_t& hash() const { return hash_; } private: std::string name_; xla::XlaComputation computation_; xla::ProgramShape program_shape_; - xla::hash_t hash_; + torch::lazy::hash_t hash_; }; using ComputationPtr = std::shared_ptr; diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 4d4bea895be..8820499a3d3 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -15,6 +15,7 @@ #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/ir_util.h" #include "torch_xla/csrc/python_util.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace { @@ -55,7 +56,7 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span tensors, GraphFormat format) { std::vector root_nodes; std::vector root_values; - std::vector root_hashes; + std::vector root_hashes; xla::util::Unique unique_device; if (indices != nullptr) { for (auto index : *indices) { @@ -91,7 +92,7 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span tensors, if (i > 0) { ss << ", "; } - ss << xla::util::HexHash(root_hashes[i]); + ss << torch::lazy::HashToString(root_hashes[i]); } ss << ")\n"; diff --git a/torch_xla/csrc/device.h b/torch_xla/csrc/device.h index 52e11f2c293..32e762fc942 100644 --- a/torch_xla/csrc/device.h +++ b/torch_xla/csrc/device.h @@ -4,6 +4,7 @@ #include #include "tensorflow/compiler/xla/xla_client/util.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { @@ -36,8 +37,8 @@ struct Device { } size_t hash() const { - return xla::util::StdHashCombine(xla::util::GetEnumValue(hw_type), - ordinal + 1); + return torch::lazy::StdHashCombine(xla::util::GetEnumValue(hw_type), + ordinal + 1); } DeviceType hw_type = DeviceType::CPU; diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index bedb600481f..40d22936111 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -9,13 +9,14 @@ #include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/compiler/xla/xla_client/util.h" #include "torch_xla/csrc/lowering_context.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace ir { namespace { using ShapeCache = - xla::util::Cache; + xla::util::Cache; struct ScopeEntry { std::string name; @@ -101,7 +102,7 @@ std::string Use::ToString() const { } size_t Output::Hasher::operator()(const Output& output) const { - return xla::util::StdHashCombine( + return torch::lazy::StdHashCombine( reinterpret_cast(output.node), output.index); } @@ -109,8 +110,8 @@ const xla::Shape& Output::shape() const { return node->shape(index); } const xla::Shape& Output::node_shape() const { return node->shape(); } -xla::hash_t Output::hash() const { - return xla::util::HashCombine(node->hash(), index); +torch::lazy::hash_t Output::hash() const { + return torch::lazy::HashCombine(node->hash(), index); } std::string Output::ToString() const { @@ -123,36 +124,36 @@ const xla::Shape& Value::shape() const { return node->shape(index); } const xla::Shape& Value::node_shape() const { return node->shape(); } -xla::hash_t Value::hash() const { - return xla::util::HashCombine(node->hash(), index); +torch::lazy::hash_t Value::hash() const { + return torch::lazy::HashCombine(node->hash(), index); } OpKind OpKind::Get(const std::string& name) { return OpKind(c10::Symbol::fromQualString(name)); } -xla::hash_t OpKind::hash() const { - return xla::util::StringHash(op.toQualString()); +torch::lazy::hash_t OpKind::hash() const { + return torch::lazy::StringHash(op.toQualString()); } Node::Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs, - xla::hash_t hash_seed) + torch::lazy::hash_t hash_seed) : op_(std::move(op)), num_outputs_(num_outputs), shape_(std::move(shape)), - node_hash_(xla::util::HashCombine(op_.hash(), hash_seed)), + node_hash_(torch::lazy::HashCombine(op_.hash(), hash_seed)), hash_(node_hash_) { metadata_.scope = GetCurrentScope(); metadata_.frame_info = GetFrameInfo(); for (auto& operand : operands) { AddOperand(operand.node, operand.index); - hash_ = xla::util::HashCombine(hash_, operand.hash()); + hash_ = torch::lazy::HashCombine(hash_, operand.hash()); } } Node::Node(OpKind op, OpList operands, const std::function& shape_fn, size_t num_outputs, - xla::hash_t hash_seed) + torch::lazy::hash_t hash_seed) : Node(std::move(op), operands, xla::Shape(), num_outputs, hash_seed) { // Forward the constructor to the one above (with empty shape), so we have the // full hash information, then fetch/compute the real shape. @@ -160,7 +161,7 @@ Node::Node(OpKind op, OpList operands, } Node::Node(OpKind op, xla::Shape shape, size_t num_outputs, - xla::hash_t hash_seed) + torch::lazy::hash_t hash_seed) : op_(std::move(op)), num_outputs_(num_outputs), shape_(std::move(shape)), @@ -247,11 +248,11 @@ XlaOpVector Node::Lower(LoweringContext* loctx) const { XLA_ERROR() << "Lowering not implemented for node: " << *this; } -xla::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape, - xla::hash_t hash_seed) { - xla::hash_t h = - xla::util::HashCombine(op.hash(), xla::util::Hash(shape.ToString())); - return xla::util::HashCombine(h, hash_seed); +torch::lazy::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape, + torch::lazy::hash_t hash_seed) { + torch::lazy::hash_t h = + torch::lazy::HashCombine(op.hash(), torch::lazy::Hash(shape.ToString())); + return torch::lazy::HashCombine(h, hash_seed); } xla::Shape Node::GetOpShape(const std::function& shape_fn) const { diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index 0a9461d8cc8..e0e26cdc461 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -17,6 +17,7 @@ #include "tensorflow/compiler/xla/xla_client/types.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "torch_xla/csrc/python_util.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace ir { @@ -83,7 +84,7 @@ struct Output { const xla::Shape& shape() const; const xla::Shape& node_shape() const; - xla::hash_t hash() const; + torch::lazy::hash_t hash() const; bool operator==(const Output& rhs) const { return node == rhs.node && index == rhs.index; @@ -120,7 +121,7 @@ struct Value { const xla::Shape& shape() const; const xla::Shape& node_shape() const; - xla::hash_t hash() const; + torch::lazy::hash_t hash() const; operator bool() const { return node != nullptr; } @@ -143,7 +144,7 @@ struct OpKind { return c10::unique_t(op) < c10::unique_t(rhs.op); } - xla::hash_t hash() const; + torch::lazy::hash_t hash() const; std::string ToString() const { return op.toQualString(); } @@ -174,15 +175,15 @@ class Node { // for the operation. The num_outputs tells how many outputs a given operation // generates. Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs = 1, - xla::hash_t hash_seed = 0x5a2d296e9); + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); // Same as the constructor above, but the shape is generated by a function, // only if needed (shape cache miss). Node(OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs = 1, xla::hash_t hash_seed = 0x5a2d296e9); + size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); // Contructor used to create leaf nodes. - Node(OpKind op, xla::Shape shape, size_t num_outputs, xla::hash_t hash_seed); + Node(OpKind op, xla::Shape shape, size_t num_outputs, torch::lazy::hash_t hash_seed); virtual ~Node(); @@ -204,9 +205,9 @@ class Node { const std::set& uses() const { return uses_; } - xla::hash_t node_hash() const { return node_hash_; } + torch::lazy::hash_t node_hash() const { return node_hash_; } - xla::hash_t hash() const { return hash_; } + torch::lazy::hash_t hash() const { return hash_; } const MetaData& metadata() const { return metadata_; } @@ -243,8 +244,8 @@ class Node { xla::Shape GetOpShape(const std::function& shape_fn) const; - static xla::hash_t GetOpHash(OpKind op, const xla::Shape& shape, - xla::hash_t hash_seed); + static torch::lazy::hash_t GetOpHash(OpKind op, const xla::Shape& shape, + torch::lazy::hash_t hash_seed); static std::vector GetFrameInfo(); @@ -260,9 +261,9 @@ class Node { // We use a set for uses, as we want deterministic use sequencing. std::set uses_; // The hash value of this node. - xla::hash_t node_hash_ = 0; + torch::lazy::hash_t node_hash_ = 0; // The hash value of the graph rooted at this node. - xla::hash_t hash_ = 0; + torch::lazy::hash_t hash_ = 0; // The IR specific metadata attached to the IR node. MetaData metadata_; // The IR framework user can attach a user defined metadata object deriving diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp index 6a2f160e6af..36621a7116f 100644 --- a/torch_xla/csrc/op_by_op_executor.cpp +++ b/torch_xla/csrc/op_by_op_executor.cpp @@ -10,11 +10,13 @@ #include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/compiler/xla/xla_client/util.h" #include "tensorflow/compiler/xla/xla_client/xla_util.h" +#include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/ir_util.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/device_data.h" #include "torch_xla/csrc/tensor_util.h" +#include "torch_xla/csrc/torch_util.h" namespace torch_xla { namespace { @@ -41,17 +43,17 @@ const xla::Shape& GetParameterShape(const ir::Output& operand, : xla::ShapeUtil::GetTupleElementShape(input_shape, operand.index); } -xla::hash_t ComputeNodeKey(const ir::Node* node, - absl::Span input_shapes, - const xla::hash_t& seed) { - xla::hash_t key = seed; +torch::lazy::hash_t ComputeNodeKey( + const ir::Node* node, absl::Span input_shapes, + const torch::lazy::hash_t& seed) { + torch::lazy::hash_t key = seed; const auto& operands = node->operands(); for (size_t i = 0; i < operands.size(); ++i) { - key = xla::util::HashCombine(key, xla::util::ShapeHash(GetParameterShape( + key = torch::lazy::HashCombine(key, torch::lazy::Hash(GetParameterShape( operands[i], *input_shapes[i]))); } - key = xla::util::HashCombine(key, xla::util::ShapeHash(node->shape())); - return xla::util::HashCombine(key, node->node_hash()); + key = torch::lazy::HashCombine(key, torch::lazy::Hash(node->shape())); + return torch::lazy::HashCombine(key, node->node_hash()); } xla::XlaComputation BuildNodeComputation( @@ -71,9 +73,9 @@ xla::XlaComputation BuildNodeComputation( return ConsumeValue(loctx.Build()); } -xla::hash_t GetNodesKeySeed(const std::string& device, - absl::Span devices) { - return xla::util::MHash(device, devices); +torch::lazy::hash_t GetNodesKeySeed(const std::string& device, + absl::Span devices) { + return torch::lazy::MHash(device, torch::lazy::Hash(devices)); } } // namespace @@ -102,12 +104,14 @@ std::vector OpByOpExecutor::BuildOps( auto compilation_devices = xla::ComputationClient::Get()->GetCompilationDevices(device, devices); - xla::hash_t nodes_key_seed = GetNodesKeySeed(device, compilation_devices); + torch::lazy::hash_t nodes_key_seed = + GetNodesKeySeed(device, compilation_devices); Device exec_device(device); - std::vector cache_keys; - std::unordered_map, xla::util::HashReducer> + std::vector cache_keys; + std::unordered_map, + torch::lazy::HashReducer> compile_indices; - std::unordered_map + std::unordered_map cache_keys_instance; std::list compile_shapes; std::vector device_data_ops(post_order.size()); @@ -133,7 +137,7 @@ std::vector OpByOpExecutor::BuildOps( op_input_shapes.push_back(ops_shapes[op_index]); } - xla::hash_t cache_key = + torch::lazy::hash_t cache_key = ComputeNodeKey(node, op_input_shapes, nodes_key_seed); cxop.computation = compile_cache_.Get(cache_key); if (cxop.computation == nullptr) { diff --git a/torch_xla/csrc/op_by_op_executor.h b/torch_xla/csrc/op_by_op_executor.h index 563e271b54c..54f964f1fce 100644 --- a/torch_xla/csrc/op_by_op_executor.h +++ b/torch_xla/csrc/op_by_op_executor.h @@ -38,8 +38,8 @@ class OpByOpExecutor { private: using CompileCache = - xla::util::Cache; + xla::util::Cache; explicit OpByOpExecutor(size_t compile_cache_size); diff --git a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp index 7693e25b4e6..5b2795dbe7d 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp @@ -27,7 +27,7 @@ AdaptiveAvgPool2d::AdaptiveAvgPool2d(const Value& input, std::vector output_size) : Node(ir::OpKind(at::aten::adaptive_avg_pool2d), {input}, [&]() { return NodeOutputShape(input, output_size); }, - /*num_outputs=*/1, xla::util::MHash(output_size)), + /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} NodePtr AdaptiveAvgPool2d::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp index 54740192d4f..6c311ac1cf2 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp @@ -27,7 +27,7 @@ AdaptiveAvgPool3d::AdaptiveAvgPool3d(const Value& input, std::vector output_size) : Node(ir::OpKind(at::aten::adaptive_avg_pool3d), {input}, [&]() { return NodeOutputShape(input, output_size); }, - /*num_outputs=*/1, xla::util::MHash(output_size)), + /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} NodePtr AdaptiveAvgPool3d::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp index 09baec633a1..e8f76e36aca 100644 --- a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp @@ -28,7 +28,7 @@ AdaptiveMaxPool2d::AdaptiveMaxPool2d(const Value& input, std::vector output_size) : Node(ir::OpKind(at::aten::adaptive_max_pool2d), {input}, [&]() { return NodeOutputShape(input, output_size); }, - /*num_outputs=*/2, xla::util::MHash(output_size)), + /*num_outputs=*/2, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} NodePtr AdaptiveMaxPool2d::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/all.cpp b/torch_xla/csrc/ops/all.cpp index 7238a86d1c9..eb009c29a99 100644 --- a/torch_xla/csrc/ops/all.cpp +++ b/torch_xla/csrc/ops/all.cpp @@ -31,7 +31,7 @@ All::All(const Value& input, std::vector dimensions, : Node(ir::OpKind(at::aten::all), {input}, NodeOutputShape(input, dimensions, keep_reduced_dimensions), /*num_outputs=*/1, - xla::util::MHash(dimensions, keep_reduced_dimensions)), + torch::lazy::MHash(dimensions, keep_reduced_dimensions)), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions) {} diff --git a/torch_xla/csrc/ops/all_reduce.cpp b/torch_xla/csrc/ops/all_reduce.cpp index 9795e7ca50c..c83c63b2c0d 100644 --- a/torch_xla/csrc/ops/all_reduce.cpp +++ b/torch_xla/csrc/ops/all_reduce.cpp @@ -37,7 +37,7 @@ AllReduce::AllReduce(AllReduceType reduce_type, : Node(xla_cross_replica_sum, GetOperandList(operands, token), [&]() { return NodeOutputShape(operands, token); }, /*num_outputs=*/operands.size() + 1, - xla::util::MHash(xla::util::GetEnumValue(reduce_type), scale, + torch::lazy::MHash(xla::util::GetEnumValue(reduce_type), scale, groups)), reduce_type_(reduce_type), scale_(scale), diff --git a/torch_xla/csrc/ops/all_to_all.cpp b/torch_xla/csrc/ops/all_to_all.cpp index 1648ccd379b..2c4f92d4a8d 100644 --- a/torch_xla/csrc/ops/all_to_all.cpp +++ b/torch_xla/csrc/ops/all_to_all.cpp @@ -37,7 +37,7 @@ AllToAll::AllToAll(const Value& input, const Value& token, concat_dimension, split_count, groups); }, /*num_outputs=*/2, - xla::util::MHash(split_dimension, concat_dimension, split_count, + torch::lazy::MHash(split_dimension, concat_dimension, split_count, groups)), split_dimension_(split_dimension), concat_dimension_(concat_dimension), diff --git a/torch_xla/csrc/ops/amax.cpp b/torch_xla/csrc/ops/amax.cpp index 6cb1fbdc2e5..cda79db570e 100644 --- a/torch_xla/csrc/ops/amax.cpp +++ b/torch_xla/csrc/ops/amax.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input, Amax::Amax(const Value& input, std::vector dimensions, bool keepdim) : Node(ir::OpKind(at::aten::amax), {input}, [&]() { return NodeOutputShape(input, dimensions, keepdim); }, - /*num_outputs=*/1, xla::util::MHash(dimensions, keepdim)), + /*num_outputs=*/1, torch::lazy::MHash(dimensions, keepdim)), dimensions_(std::move(dimensions)), keepdim_(keepdim) {} diff --git a/torch_xla/csrc/ops/amin.cpp b/torch_xla/csrc/ops/amin.cpp index 6560abe983d..51350ea54e7 100644 --- a/torch_xla/csrc/ops/amin.cpp +++ b/torch_xla/csrc/ops/amin.cpp @@ -24,7 +24,7 @@ xla::Shape NodeOutputShape(const Value& input, Amin::Amin(const Value& input, std::vector dimensions, bool keepdim) : Node(ir::OpKind(at::aten::amin), {input}, [&]() { return NodeOutputShape(input, dimensions, keepdim); }, - /*num_outputs=*/1, xla::util::MHash(dimensions, keepdim)), + /*num_outputs=*/1, torch::lazy::MHash(dimensions, keepdim)), dimensions_(std::move(dimensions)), keepdim_(keepdim) {} diff --git a/torch_xla/csrc/ops/any.cpp b/torch_xla/csrc/ops/any.cpp index 5d3d2069fd8..b5db12ade8b 100644 --- a/torch_xla/csrc/ops/any.cpp +++ b/torch_xla/csrc/ops/any.cpp @@ -31,7 +31,7 @@ Any::Any(const Value& input, std::vector dimensions, : Node(ir::OpKind(at::aten::any), {input}, NodeOutputShape(input, dimensions, keep_reduced_dimensions), /*num_outputs=*/1, - xla::util::MHash(dimensions, keep_reduced_dimensions)), + torch::lazy::MHash(dimensions, keep_reduced_dimensions)), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions) {} diff --git a/torch_xla/csrc/ops/arg_max.cpp b/torch_xla/csrc/ops/arg_max.cpp index c26968c931c..56618008047 100644 --- a/torch_xla/csrc/ops/arg_max.cpp +++ b/torch_xla/csrc/ops/arg_max.cpp @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(const Value& input, xla::int64 dim, bool keepdim) { ArgMax::ArgMax(const Value& input, xla::int64 dim, bool keepdim) : Node(ir::OpKind(at::aten::argmax), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, - /*num_outputs=*/1, xla::util::MHash(dim, keepdim)), + /*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)), dim_(dim), keepdim_(keepdim) {} diff --git a/torch_xla/csrc/ops/arg_min.cpp b/torch_xla/csrc/ops/arg_min.cpp index 4f5650027ed..512dc5c2be7 100644 --- a/torch_xla/csrc/ops/arg_min.cpp +++ b/torch_xla/csrc/ops/arg_min.cpp @@ -23,7 +23,7 @@ xla::Shape NodeOutputShape(const Value& input, xla::int64 dim, bool keepdim) { ArgMin::ArgMin(const Value& input, xla::int64 dim, bool keepdim) : Node(ir::OpKind(at::aten::argmin), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, - /*num_outputs=*/1, xla::util::MHash(dim, keepdim)), + /*num_outputs=*/1, torch::lazy::MHash(dim, keepdim)), dim_(dim), keepdim_(keepdim) {} diff --git a/torch_xla/csrc/ops/as_strided.cpp b/torch_xla/csrc/ops/as_strided.cpp index aa599ce52e0..e4ed625e7de 100644 --- a/torch_xla/csrc/ops/as_strided.cpp +++ b/torch_xla/csrc/ops/as_strided.cpp @@ -49,7 +49,7 @@ AsStrided::AsStrided(const Value& input, std::vector size, return xla::ShapeUtil::MakeShape(input.shape().element_type(), size); }, - /*num_outputs=*/1, xla::util::MHash(size, stride, storage_offset)), + /*num_outputs=*/1, torch::lazy::MHash(size, stride, storage_offset)), size_(std::move(size)), stride_(std::move(stride)), storage_offset_(storage_offset) {} diff --git a/torch_xla/csrc/ops/as_strided_view_update.cpp b/torch_xla/csrc/ops/as_strided_view_update.cpp index 54fbd2a8a02..6e39dde03b0 100644 --- a/torch_xla/csrc/ops/as_strided_view_update.cpp +++ b/torch_xla/csrc/ops/as_strided_view_update.cpp @@ -51,7 +51,7 @@ AsStridedViewUpdate::AsStridedViewUpdate(const Value& target, return xla::ShapeUtil::MakeShape(target.shape().element_type(), size); }, - /*num_outputs=*/1, xla::util::MHash(size, stride, storage_offset)), + /*num_outputs=*/1, torch::lazy::MHash(size, stride, storage_offset)), size_(std::move(size)), stride_(std::move(stride)), storage_offset_(storage_offset) {} diff --git a/torch_xla/csrc/ops/avg_pool_nd.cpp b/torch_xla/csrc/ops/avg_pool_nd.cpp index 191ce07e239..06fca993b8e 100644 --- a/torch_xla/csrc/ops/avg_pool_nd.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd.cpp @@ -56,7 +56,7 @@ AvgPoolNd::AvgPoolNd(const Value& input, xla::int64 spatial_dim_count, count_include_pad); }, /*num_outputs=*/1, - xla::util::MHash(spatial_dim_count, kernel_size, stride, padding, + torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, ceil_mode, count_include_pad)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), diff --git a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp index 810edf843b1..dc3c65b4485 100644 --- a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp @@ -55,7 +55,7 @@ AvgPoolNdBackward::AvgPoolNdBackward( count_include_pad); }, /*num_outputs=*/1, - xla::util::MHash(spatial_dim_count, kernel_size, stride, padding, + torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, ceil_mode, count_include_pad)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), diff --git a/torch_xla/csrc/ops/binary_cross_entropy.cpp b/torch_xla/csrc/ops/binary_cross_entropy.cpp index 025d18d5605..213dabe3f93 100644 --- a/torch_xla/csrc/ops/binary_cross_entropy.cpp +++ b/torch_xla/csrc/ops/binary_cross_entropy.cpp @@ -38,7 +38,7 @@ BinaryCrossEntropy::BinaryCrossEntropy(const Value& logits, const Value& labels, xla::util::GetValuesVector({logits, labels}, {&weight}), [&]() { return NodeOutputShape(logits, labels, weight, reduction); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction))), + torch::lazy::MHash(xla::util::GetEnumValue(reduction))), reduction_(reduction) {} NodePtr BinaryCrossEntropy::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp b/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp index 58ac7b1c29e..cb95f47da15 100644 --- a/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp +++ b/torch_xla/csrc/ops/binary_cross_entropy_backward.cpp @@ -44,7 +44,7 @@ BinaryCrossEntropyBackward::BinaryCrossEntropyBackward( reduction); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction))), + torch::lazy::MHash(xla::util::GetEnumValue(reduction))), reduction_(reduction) {} NodePtr BinaryCrossEntropyBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp index 84c08700a85..471a7f21c05 100644 --- a/torch_xla/csrc/ops/cast.cpp +++ b/torch_xla/csrc/ops/cast.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, xla::PrimitiveType type) { Cast::Cast(const Value& input, xla::PrimitiveType type) : Node(xla_cast, {input}, NodeOutputShape(input, type), - /*num_outputs=*/1, xla::util::MHash(static_cast(type))), + /*num_outputs=*/1, torch::lazy::MHash(static_cast(type))), type_(type) {} Cast::Cast(const Value& input, at::ScalarType dtype, @@ -35,7 +35,7 @@ Cast::Cast(const Value& input, at::ScalarType dtype, NodeOutputShape(input, MakeXlaPrimitiveType(dtype, /*device=*/nullptr)), /*num_outputs=*/1, - xla::util::MHash(101, static_cast(dtype), + torch::lazy::MHash(101, static_cast(dtype), OptionalOr(stype, -1))), type_(MakeXlaPrimitiveType(dtype, /*device=*/nullptr)), dtype_(dtype), diff --git a/torch_xla/csrc/ops/cat.cpp b/torch_xla/csrc/ops/cat.cpp index 9fc7bf97eba..3b836202bc3 100644 --- a/torch_xla/csrc/ops/cat.cpp +++ b/torch_xla/csrc/ops/cat.cpp @@ -29,7 +29,7 @@ xla::Shape NodeOutputShape(absl::Span values, xla::int64 dim) { Cat::Cat(absl::Span values, xla::int64 dim) : Node(ir::OpKind(at::aten::cat), values, [&]() { return NodeOutputShape(values, dim); }, - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr Cat::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/cholesky.cpp b/torch_xla/csrc/ops/cholesky.cpp index 4da51579945..1d20082d4b0 100644 --- a/torch_xla/csrc/ops/cholesky.cpp +++ b/torch_xla/csrc/ops/cholesky.cpp @@ -11,7 +11,7 @@ namespace ops { Cholesky::Cholesky(const Value& input, bool lower) : Node(ir::OpKind(at::aten::cholesky), {input}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(lower)), + /*num_outputs=*/1, torch::lazy::MHash(lower)), lower_(lower) {} NodePtr Cholesky::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/collective_permute.cpp b/torch_xla/csrc/ops/collective_permute.cpp index 4307f05d9ca..2eec2663d6f 100644 --- a/torch_xla/csrc/ops/collective_permute.cpp +++ b/torch_xla/csrc/ops/collective_permute.cpp @@ -30,7 +30,7 @@ CollectivePermute::CollectivePermute( std::vector> source_target_pairs) : Node(xla_collective_permute, {input, token}, [&]() { return NodeOutputShape(input, token, source_target_pairs); }, - /*num_outputs=*/2, xla::util::MHash(source_target_pairs)), + /*num_outputs=*/2, torch::lazy::MHash(source_target_pairs)), source_target_pairs_(std::move(source_target_pairs)) {} NodePtr CollectivePermute::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/constant_pad_nd.cpp b/torch_xla/csrc/ops/constant_pad_nd.cpp index 2d9f34623ad..72364ac7dec 100644 --- a/torch_xla/csrc/ops/constant_pad_nd.cpp +++ b/torch_xla/csrc/ops/constant_pad_nd.cpp @@ -38,7 +38,7 @@ ConstantPadNd::ConstantPadNd(const Value& input, std::vector pad, const at::Scalar& value) : Node(ir::OpKind(at::aten::constant_pad_nd), {input}, [&]() { return NodeOutputShape(input, value, pad); }, - /*num_outputs=*/1, xla::util::MHash(pad, ScalarHash(value))), + /*num_outputs=*/1, torch::lazy::MHash(pad, ScalarHash(value))), pad_(std::move(pad)), value_(value) {} diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp index d1bba60849b..118aee1224c 100644 --- a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp @@ -47,7 +47,7 @@ ConvolutionBackwardOverrideable::ConvolutionBackwardOverrideable( groups); }, /*num_outputs=*/3, - xla::util::MHash(stride, padding, dilation, transposed, + torch::lazy::MHash(stride, padding, dilation, transposed, output_padding, groups)), stride_(std::move(stride)), padding_(std::move(padding)), diff --git a/torch_xla/csrc/ops/convolution_overrideable.cpp b/torch_xla/csrc/ops/convolution_overrideable.cpp index 2799a5e11c4..6581c3b5fe6 100644 --- a/torch_xla/csrc/ops/convolution_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_overrideable.cpp @@ -45,7 +45,7 @@ ConvolutionOverrideable::ConvolutionOverrideable( transposed, output_padding, groups); }, /*num_outputs=*/1, - xla::util::MHash(stride, padding, dilation, transposed, + torch::lazy::MHash(stride, padding, dilation, transposed, output_padding, groups)), stride_(std::move(stride)), padding_(std::move(padding)), @@ -64,7 +64,7 @@ ConvolutionOverrideable::ConvolutionOverrideable( transposed, output_padding, groups); }, /*num_outputs=*/1, - xla::util::MHash(stride, padding, dilation, transposed, + torch::lazy::MHash(stride, padding, dilation, transposed, output_padding, groups)), stride_(std::move(stride)), padding_(std::move(padding)), diff --git a/torch_xla/csrc/ops/cumprod.cpp b/torch_xla/csrc/ops/cumprod.cpp index 12d706de8d4..fd06004a1e3 100644 --- a/torch_xla/csrc/ops/cumprod.cpp +++ b/torch_xla/csrc/ops/cumprod.cpp @@ -41,7 +41,7 @@ CumProd::CumProd(const Value& input, xla::int64 dim, : Node(ir::OpKind(at::aten::cumprod), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, - xla::util::MHash(dim, OptionalOr(dtype, -1))), + torch::lazy::MHash(dim, OptionalOr(dtype, -1))), dim_(dim), dtype_(dtype) {} diff --git a/torch_xla/csrc/ops/cumsum.cpp b/torch_xla/csrc/ops/cumsum.cpp index 911a04f9352..0e73e308e1b 100644 --- a/torch_xla/csrc/ops/cumsum.cpp +++ b/torch_xla/csrc/ops/cumsum.cpp @@ -40,7 +40,7 @@ CumSum::CumSum(const Value& input, xla::int64 dim, : Node(ir::OpKind(at::aten::cumsum), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, - xla::util::MHash(dim, OptionalOr(dtype, -1))), + torch::lazy::MHash(dim, OptionalOr(dtype, -1))), dim_(dim), dtype_(dtype) {} diff --git a/torch_xla/csrc/ops/device_data.cpp b/torch_xla/csrc/ops/device_data.cpp index a5a969fbe69..6cec8d5db5f 100644 --- a/torch_xla/csrc/ops/device_data.cpp +++ b/torch_xla/csrc/ops/device_data.cpp @@ -11,7 +11,7 @@ namespace ops { DeviceData::DeviceData(std::shared_ptr data) : Node(xla_device_data, data->shape(), /*num_outputs=*/1, - /*hash_seed=*/101), + /*hash_seed=*/(uint32_t)101), data_(std::move(data)) {} std::string DeviceData::ToString() const { diff --git a/torch_xla/csrc/ops/diagonal.cpp b/torch_xla/csrc/ops/diagonal.cpp index 2123d5f901d..5b43f5abf6c 100644 --- a/torch_xla/csrc/ops/diagonal.cpp +++ b/torch_xla/csrc/ops/diagonal.cpp @@ -18,7 +18,7 @@ Diagonal::Diagonal(const Value& input, xla::int64 offset, xla::int64 dim1, [&]() { return MakeDiagonalShape(input.shape(), offset, dim1, dim2); }, - /*num_outputs=*/1, xla::util::MHash(offset, dim1, dim2)), + /*num_outputs=*/1, torch::lazy::MHash(offset, dim1, dim2)), offset_(offset), dim1_(dim1), dim2_(dim2) {} diff --git a/torch_xla/csrc/ops/diagonal_view_update.cpp b/torch_xla/csrc/ops/diagonal_view_update.cpp index 18d72b9f6e5..219a3953f20 100644 --- a/torch_xla/csrc/ops/diagonal_view_update.cpp +++ b/torch_xla/csrc/ops/diagonal_view_update.cpp @@ -13,7 +13,7 @@ DiagonalViewUpdate::DiagonalViewUpdate(const Value& target, const Value& input, xla::int64 offset, xla::int64 dim1, xla::int64 dim2) : Node(xla_diagonal_view_update, {target, input}, target.shape(), - /*num_outputs=*/1, xla::util::MHash(offset, dim1, dim2)), + /*num_outputs=*/1, torch::lazy::MHash(offset, dim1, dim2)), offset_(offset), dim1_(dim1), dim2_(dim2) {} diff --git a/torch_xla/csrc/ops/discrete_uniform.cpp b/torch_xla/csrc/ops/discrete_uniform.cpp index 3953d90eff8..c74a16bc314 100644 --- a/torch_xla/csrc/ops/discrete_uniform.cpp +++ b/torch_xla/csrc/ops/discrete_uniform.cpp @@ -5,6 +5,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/random.h" +#include "torch_xla/csrc/torch_util.h" namespace torch_xla { namespace ir { @@ -13,7 +14,7 @@ namespace ops { DiscreteUniform::DiscreteUniform(const Value& from, const Value& to, const Value& seed, const xla::Shape& rng_shape) : Node(ir::OpKind(at::aten::random), {from, to, seed}, rng_shape, - /*num_outputs=*/1, xla::util::ShapeHash(rng_shape)) {} + /*num_outputs=*/1, torch::lazy::Hash(rng_shape)) {} NodePtr DiscreteUniform::Clone(OpList operands) const { return MakeNode(operands.at(0), operands.at(1), diff --git a/torch_xla/csrc/ops/expand.cpp b/torch_xla/csrc/ops/expand.cpp index 17dde9c9597..ab63e16bdc6 100644 --- a/torch_xla/csrc/ops/expand.cpp +++ b/torch_xla/csrc/ops/expand.cpp @@ -25,7 +25,7 @@ xla::Shape NodeOutputShape(const Value& input, Expand::Expand(const Value& input, std::vector size) : Node(ir::OpKind(at::aten::expand), {input}, [&]() { return NodeOutputShape(input, size); }, - /*num_outputs=*/1, xla::util::MHash(size)), + /*num_outputs=*/1, torch::lazy::MHash(size)), size_(std::move(size)) {} NodePtr Expand::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/flip.cpp b/torch_xla/csrc/ops/flip.cpp index ede41a036d5..29d7b0c31d3 100644 --- a/torch_xla/csrc/ops/flip.cpp +++ b/torch_xla/csrc/ops/flip.cpp @@ -10,7 +10,7 @@ namespace ops { Flip::Flip(const Value& input, std::vector dims) : Node(ir::OpKind(at::aten::flip), {input}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(dims)), + /*num_outputs=*/1, torch::lazy::MHash(dims)), dims_(std::move(dims)) {} NodePtr Flip::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/gather.cpp b/torch_xla/csrc/ops/gather.cpp index 3868e1ca2e7..8b6119853e7 100644 --- a/torch_xla/csrc/ops/gather.cpp +++ b/torch_xla/csrc/ops/gather.cpp @@ -27,7 +27,7 @@ xla::Shape NodeOutputShape(const Value& input, const Value& index, Gather::Gather(const Value& input, xla::int64 dim, const Value& index) : Node(ir::OpKind(at::aten::gather), {input, index}, [&]() { return NodeOutputShape(input, index, dim); }, - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr Gather::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/generic.cpp b/torch_xla/csrc/ops/generic.cpp index 9b8f9526f20..85fa5532576 100644 --- a/torch_xla/csrc/ops/generic.cpp +++ b/torch_xla/csrc/ops/generic.cpp @@ -7,20 +7,20 @@ namespace ir { namespace ops { Generic::Generic(OpKind op, absl::Span operands, xla::Shape shape, - LowerFn lower_fn, size_t num_outputs, xla::hash_t hash_seed) + LowerFn lower_fn, size_t num_outputs, torch::lazy::hash_t hash_seed) : Node(std::move(op), operands, std::move(shape), num_outputs, hash_seed), lower_fn_(std::move(lower_fn)), hash_seed_(hash_seed) {} Generic::Generic(OpKind op, absl::Span operands, const std::function& shape_fn, LowerFn lower_fn, - size_t num_outputs, xla::hash_t hash_seed) + size_t num_outputs, torch::lazy::hash_t hash_seed) : Node(std::move(op), operands, shape_fn, num_outputs, hash_seed), lower_fn_(std::move(lower_fn)), hash_seed_(hash_seed) {} Generic::Generic(OpKind op, xla::Shape shape, LowerFn lower_fn, - size_t num_outputs, xla::hash_t hash_seed) + size_t num_outputs, torch::lazy::hash_t hash_seed) : Node(std::move(op), std::move(shape), num_outputs, hash_seed), lower_fn_(std::move(lower_fn)), hash_seed_(hash_seed) {} diff --git a/torch_xla/csrc/ops/generic.h b/torch_xla/csrc/ops/generic.h index 277d6f362c0..394fd86aebc 100644 --- a/torch_xla/csrc/ops/generic.h +++ b/torch_xla/csrc/ops/generic.h @@ -17,14 +17,14 @@ class Generic : public Node { Generic(OpKind op, absl::Span operands, xla::Shape shape, LowerFn lower_fn, size_t num_outputs = 1, - xla::hash_t hash_seed = 0x5a2d296e9); + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); Generic(OpKind op, absl::Span operands, const std::function& shape_fn, LowerFn lower_fn, - size_t num_outputs = 1, xla::hash_t hash_seed = 0x5a2d296e9); + size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); Generic(OpKind op, xla::Shape shape, LowerFn lower_fn, size_t num_outputs, - xla::hash_t hash_seed); + torch::lazy::hash_t hash_seed); NodePtr Clone(OpList operands) const override; @@ -32,7 +32,7 @@ class Generic : public Node { private: LowerFn lower_fn_; - xla::hash_t hash_seed_; + torch::lazy::hash_t hash_seed_; }; } // namespace ops diff --git a/torch_xla/csrc/ops/generic_slice.cpp b/torch_xla/csrc/ops/generic_slice.cpp index 09610e85c7e..b27e39aedd1 100644 --- a/torch_xla/csrc/ops/generic_slice.cpp +++ b/torch_xla/csrc/ops/generic_slice.cpp @@ -6,6 +6,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/torch_util.h" namespace torch_xla { namespace ir { @@ -29,7 +30,7 @@ GenericSlice::GenericSlice(const Value& input, absl::Span sizes) : Node(xla_generic_slice, {input}, [&]() { return NodeOutputShape(input, base_indices, sizes); }, - /*num_outputs=*/1, xla::util::MHash(base_indices, sizes)), + /*num_outputs=*/1, torch::lazy::MHash(torch::lazy::Hash(base_indices), torch::lazy::Hash(sizes))), base_indices_(base_indices.begin(), base_indices.end()), sizes_(sizes.begin(), sizes.end()) {} diff --git a/torch_xla/csrc/ops/get_dimensions_size.cpp b/torch_xla/csrc/ops/get_dimensions_size.cpp index de37d149b0d..29ae72f21fa 100644 --- a/torch_xla/csrc/ops/get_dimensions_size.cpp +++ b/torch_xla/csrc/ops/get_dimensions_size.cpp @@ -17,7 +17,7 @@ GetDimensionsSize::GetDimensionsSize(const Value& input, : Node(xla_get_dimensions_size, {input}, xla::ShapeUtil::MakeShape(GetShapeDimensionType(/*device=*/nullptr), {}), - /*num_outputs=*/1, xla::util::MHash(dimensions)), + /*num_outputs=*/1, torch::lazy::MHash(dimensions)), dimensions_(std::move(dimensions)) {} NodePtr GetDimensionsSize::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/hardtanh_backward.cpp b/torch_xla/csrc/ops/hardtanh_backward.cpp index fcbdbe1fa50..a7de77ac8e6 100644 --- a/torch_xla/csrc/ops/hardtanh_backward.cpp +++ b/torch_xla/csrc/ops/hardtanh_backward.cpp @@ -14,7 +14,7 @@ HardtanhBackward::HardtanhBackward(const Value& grad_output, const Value& input, const at::Scalar& max_val) : Node(OpKind(at::aten::hardtanh_backward), {grad_output, input}, grad_output.shape(), /*num_outputs=*/1, - xla::util::MHash(ScalarHash(min_val), ScalarHash(max_val))), + torch::lazy::MHash(ScalarHash(min_val), ScalarHash(max_val))), min_val_(min_val), max_val_(max_val) {} diff --git a/torch_xla/csrc/ops/index_get.cpp b/torch_xla/csrc/ops/index_get.cpp index 073a158c581..fb4df1b096c 100644 --- a/torch_xla/csrc/ops/index_get.cpp +++ b/torch_xla/csrc/ops/index_get.cpp @@ -27,7 +27,7 @@ IndexGet::IndexGet(const ir::Value& base, const ir::Value& indices, xla::int64 start_dim) : Node(OpKind(at::aten::index), {base, indices}, [&]() { return NodeOutputShape(base, indices, start_dim); }, - /*num_outputs=*/1, xla::util::MHash(start_dim)), + /*num_outputs=*/1, torch::lazy::MHash(start_dim)), start_dim_(start_dim) {} std::string IndexGet::ToString() const { diff --git a/torch_xla/csrc/ops/index_ops.cpp b/torch_xla/csrc/ops/index_ops.cpp index 1ed3eee07e9..eaf8d655f31 100644 --- a/torch_xla/csrc/ops/index_ops.cpp +++ b/torch_xla/csrc/ops/index_ops.cpp @@ -170,7 +170,7 @@ ir::NodePtr IndexFillOp(const ir::Value& buffer, xla::int64 dim, {buffer.shape(), index_rank1.shape(), value.shape()}, lower_for_shape_fn); }, - std::move(lower_fn), /*num_outputs=*/1, xla::util::MHash(dim)); + std::move(lower_fn), /*num_outputs=*/1, torch::lazy::MHash(dim)); } ir::NodePtr IndexAddOp(const ir::Value& buffer, xla::int64 dim, diff --git a/torch_xla/csrc/ops/index_put.cpp b/torch_xla/csrc/ops/index_put.cpp index eb852cda8bd..d0cbad3a6cc 100644 --- a/torch_xla/csrc/ops/index_put.cpp +++ b/torch_xla/csrc/ops/index_put.cpp @@ -12,7 +12,7 @@ IndexPut::IndexPut(const ir::Value& base, const ir::Value& indices, xla::int64 start_dim, const ir::Value& values, bool accumulate) : Node(OpKind(at::aten::index_put), {base, indices, values}, base.shape(), - /*num_outputs=*/1, xla::util::MHash(start_dim, accumulate)), + /*num_outputs=*/1, torch::lazy::MHash(start_dim, accumulate)), start_dim_(start_dim), accumulate_(accumulate) {} diff --git a/torch_xla/csrc/ops/index_select.cpp b/torch_xla/csrc/ops/index_select.cpp index d0e4979aa5e..9e2c95dc499 100644 --- a/torch_xla/csrc/ops/index_select.cpp +++ b/torch_xla/csrc/ops/index_select.cpp @@ -25,7 +25,7 @@ xla::Shape NodeOutputShape(const Value& input, const Value& index, IndexSelect::IndexSelect(const Value& input, xla::int64 dim, const Value& index) : Node(ir::OpKind(at::aten::index_select), {input, index}, [&]() { return NodeOutputShape(input, index, dim); }, - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr IndexSelect::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/kth_value.cpp b/torch_xla/csrc/ops/kth_value.cpp index 249b1b31cbe..8c7e4de4fd8 100644 --- a/torch_xla/csrc/ops/kth_value.cpp +++ b/torch_xla/csrc/ops/kth_value.cpp @@ -26,7 +26,7 @@ KthValue::KthValue(const Value& input, xla::int64 k, xla::int64 dim, bool keepdim) : Node(ir::OpKind(at::aten::kthvalue), {input}, [&]() { return NodeOutputShape(input, k, dim, keepdim); }, - /*num_outputs=*/2, xla::util::MHash(k, dim, keepdim)), + /*num_outputs=*/2, torch::lazy::MHash(k, dim, keepdim)), k_(k), dim_(dim), keepdim_(keepdim) {} diff --git a/torch_xla/csrc/ops/l1_loss.cpp b/torch_xla/csrc/ops/l1_loss.cpp index d0764b06cb9..781545e7b78 100644 --- a/torch_xla/csrc/ops/l1_loss.cpp +++ b/torch_xla/csrc/ops/l1_loss.cpp @@ -25,7 +25,7 @@ L1Loss::L1Loss(const Value& input, const Value& target, ReductionMode reduction) : Node(ir::OpKind(at::aten::l1_loss), {input, target}, [&]() { return NodeOutputShape(input, target, reduction); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction))), + torch::lazy::MHash(xla::util::GetEnumValue(reduction))), reduction_(reduction) {} NodePtr L1Loss::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/l1_loss_backward.cpp b/torch_xla/csrc/ops/l1_loss_backward.cpp index 8c1fe3c590d..39e7a38da6c 100644 --- a/torch_xla/csrc/ops/l1_loss_backward.cpp +++ b/torch_xla/csrc/ops/l1_loss_backward.cpp @@ -29,7 +29,7 @@ L1LossBackward::L1LossBackward(const Value& grad_output, const Value& input, return NodeOutputShape(grad_output, input, target, reduction); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction))), + torch::lazy::MHash(xla::util::GetEnumValue(reduction))), reduction_(reduction) {} NodePtr L1LossBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/leaky_relu.cpp b/torch_xla/csrc/ops/leaky_relu.cpp index 1e9ba6cfd00..0bcbababe14 100644 --- a/torch_xla/csrc/ops/leaky_relu.cpp +++ b/torch_xla/csrc/ops/leaky_relu.cpp @@ -10,7 +10,7 @@ namespace ops { LeakyRelu::LeakyRelu(const Value& input, double negative_slope) : Node(ir::OpKind(at::aten::leaky_relu), {input}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(negative_slope)), + /*num_outputs=*/1, torch::lazy::MHash(negative_slope)), negative_slope_(negative_slope) {} NodePtr LeakyRelu::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/leaky_relu_backward.cpp b/torch_xla/csrc/ops/leaky_relu_backward.cpp index 8331943ea73..3103eecc46d 100644 --- a/torch_xla/csrc/ops/leaky_relu_backward.cpp +++ b/torch_xla/csrc/ops/leaky_relu_backward.cpp @@ -12,7 +12,7 @@ LeakyReluBackward::LeakyReluBackward(const Value& grad_output, const Value& input, double negative_slope) : Node(ir::OpKind(at::aten::leaky_relu_backward), {grad_output, input}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(negative_slope)), + /*num_outputs=*/1, torch::lazy::MHash(negative_slope)), negative_slope_(negative_slope) {} NodePtr LeakyReluBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/linear_interpolation.cpp b/torch_xla/csrc/ops/linear_interpolation.cpp index f539b9cc645..bb98c63d70b 100644 --- a/torch_xla/csrc/ops/linear_interpolation.cpp +++ b/torch_xla/csrc/ops/linear_interpolation.cpp @@ -12,7 +12,7 @@ namespace ops { LinearInterpolation::LinearInterpolation(const Value& value, const Value& new_value, double alpha) : Node(xla_moving_average, {value, new_value}, value.shape(), - /*num_outputs=*/1, xla::util::MHash(alpha)), + /*num_outputs=*/1, torch::lazy::MHash(alpha)), alpha_(alpha) {} NodePtr LinearInterpolation::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/log_softmax.cpp b/torch_xla/csrc/ops/log_softmax.cpp index 4d49c1348ac..34fec5fbabe 100644 --- a/torch_xla/csrc/ops/log_softmax.cpp +++ b/torch_xla/csrc/ops/log_softmax.cpp @@ -34,7 +34,7 @@ LogSoftmax::LogSoftmax(const Value& input, xla::int64 dim, : Node(ir::OpKind(at::aten::log_softmax), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, - xla::util::MHash(dim, OptionalOr(dtype, -1))), + torch::lazy::MHash(dim, OptionalOr(dtype, -1))), dim_(dim), dtype_(dtype) {} diff --git a/torch_xla/csrc/ops/log_softmax_backward.cpp b/torch_xla/csrc/ops/log_softmax_backward.cpp index 41514212761..5039ed3143c 100644 --- a/torch_xla/csrc/ops/log_softmax_backward.cpp +++ b/torch_xla/csrc/ops/log_softmax_backward.cpp @@ -14,7 +14,7 @@ LogSoftmaxBackward::LogSoftmaxBackward(const Value& grad_output, const Value& output, xla::int64 dim) : Node(ir::OpKind(at::aten::_log_softmax_backward_data), {grad_output, output}, grad_output.shape(), - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr LogSoftmaxBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/logsumexp.cpp b/torch_xla/csrc/ops/logsumexp.cpp index 2d3edaf4fdc..300bf273719 100644 --- a/torch_xla/csrc/ops/logsumexp.cpp +++ b/torch_xla/csrc/ops/logsumexp.cpp @@ -33,7 +33,7 @@ Logsumexp::Logsumexp(const Value& input, std::vector dimensions, return NodeOutputShape(input, dimensions, keep_reduced_dimensions); }, /*num_outputs=*/1, - xla::util::MHash(dimensions, keep_reduced_dimensions)), + torch::lazy::MHash(dimensions, keep_reduced_dimensions)), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions) {} diff --git a/torch_xla/csrc/ops/max_in_dim.cpp b/torch_xla/csrc/ops/max_in_dim.cpp index 313388b7de8..44e23b1dad5 100644 --- a/torch_xla/csrc/ops/max_in_dim.cpp +++ b/torch_xla/csrc/ops/max_in_dim.cpp @@ -25,7 +25,7 @@ xla::Shape NodeOutputShape(const Value& input, xla::int64 dim, bool keepdim) { MaxInDim::MaxInDim(const Value& input, xla::int64 dim, bool keepdim) : Node(ir::OpKind(at::aten::max), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, - /*num_outputs=*/2, xla::util::MHash(dim, keepdim)), + /*num_outputs=*/2, torch::lazy::MHash(dim, keepdim)), dim_(dim), keepdim_(keepdim) {} diff --git a/torch_xla/csrc/ops/max_pool_nd.cpp b/torch_xla/csrc/ops/max_pool_nd.cpp index 931c43c3f90..2d1046200f6 100644 --- a/torch_xla/csrc/ops/max_pool_nd.cpp +++ b/torch_xla/csrc/ops/max_pool_nd.cpp @@ -51,7 +51,7 @@ MaxPoolNd::MaxPoolNd(const Value& input, xla::int64 spatial_dim_count, stride, padding, ceil_mode); }, /*num_outputs=*/2, - xla::util::MHash(spatial_dim_count, kernel_size, stride, padding, + torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, ceil_mode)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), diff --git a/torch_xla/csrc/ops/max_pool_nd_backward.cpp b/torch_xla/csrc/ops/max_pool_nd_backward.cpp index a8ae5c48810..d8959146eea 100644 --- a/torch_xla/csrc/ops/max_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_pool_nd_backward.cpp @@ -53,7 +53,7 @@ MaxPoolNdBackward::MaxPoolNdBackward( kernel_size, stride, padding, ceil_mode); }, /*num_outputs=*/1, - xla::util::MHash(spatial_dim_count, kernel_size, stride, padding, + torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, ceil_mode)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), diff --git a/torch_xla/csrc/ops/max_unpool_nd.cpp b/torch_xla/csrc/ops/max_unpool_nd.cpp index 5ec814884cc..770656f3095 100644 --- a/torch_xla/csrc/ops/max_unpool_nd.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd.cpp @@ -38,7 +38,7 @@ MaxUnpoolNd::MaxUnpoolNd(const Value& input, const Value& indices, std::vector output_size) : Node(ir::OpKind(MaxUnpoolNdSymbol(output_size.size())), {input, indices}, [&]() { return NodeOutputShape(input, indices, output_size); }, - /*num_outputs=*/1, xla::util::MHash(output_size)), + /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} NodePtr MaxUnpoolNd::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp index c19d411e64a..79e530924b9 100644 --- a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp @@ -45,7 +45,7 @@ MaxUnpoolNdBackward::MaxUnpoolNdBackward(const Value& grad_output, [&]() { return NodeOutputShape(grad_output, input, indices, output_size); }, - /*num_outputs=*/1, xla::util::MHash(output_size)), + /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} NodePtr MaxUnpoolNdBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/mean.cpp b/torch_xla/csrc/ops/mean.cpp index 43c0aa0b59f..b11885aef5a 100644 --- a/torch_xla/csrc/ops/mean.cpp +++ b/torch_xla/csrc/ops/mean.cpp @@ -45,7 +45,7 @@ Mean::Mean(const Value& input, std::vector dimensions, dtype); }, /*num_outputs=*/1, - xla::util::MHash(dimensions, keep_reduced_dimensions, + torch::lazy::MHash(dimensions, keep_reduced_dimensions, OptionalOr(dtype, -1))), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions), diff --git a/torch_xla/csrc/ops/min_in_dim.cpp b/torch_xla/csrc/ops/min_in_dim.cpp index f1e29a1ae18..ec5e0ba24b9 100644 --- a/torch_xla/csrc/ops/min_in_dim.cpp +++ b/torch_xla/csrc/ops/min_in_dim.cpp @@ -25,7 +25,7 @@ xla::Shape NodeOutputShape(const Value& input, xla::int64 dim, bool keepdim) { MinInDim::MinInDim(const Value& input, xla::int64 dim, bool keepdim) : Node(ir::OpKind(at::aten::min), {input}, [&]() { return NodeOutputShape(input, dim, keepdim); }, - /*num_outputs=*/2, xla::util::MHash(dim, keepdim)), + /*num_outputs=*/2, torch::lazy::MHash(dim, keepdim)), dim_(dim), keepdim_(keepdim) {} diff --git a/torch_xla/csrc/ops/mse_loss.cpp b/torch_xla/csrc/ops/mse_loss.cpp index 4924f0233ca..08a438ee8c8 100644 --- a/torch_xla/csrc/ops/mse_loss.cpp +++ b/torch_xla/csrc/ops/mse_loss.cpp @@ -28,7 +28,7 @@ MseLoss::MseLoss(const Value& input, const Value& target, : Node(ir::OpKind(at::aten::mse_loss), {input, target}, [&]() { return NodeOutputShape(input, target, reduction); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction))), + torch::lazy::MHash(xla::util::GetEnumValue(reduction))), reduction_(reduction) {} NodePtr MseLoss::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/mse_loss_backward.cpp b/torch_xla/csrc/ops/mse_loss_backward.cpp index 5614aeb1467..4f34ee8ceb6 100644 --- a/torch_xla/csrc/ops/mse_loss_backward.cpp +++ b/torch_xla/csrc/ops/mse_loss_backward.cpp @@ -32,7 +32,7 @@ MseLossBackward::MseLossBackward(const Value& grad_output, const Value& input, return NodeOutputShape(grad_output, input, target, reduction); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction))), + torch::lazy::MHash(xla::util::GetEnumValue(reduction))), reduction_(reduction) {} NodePtr MseLossBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/native_batch_norm_backward.cpp b/torch_xla/csrc/ops/native_batch_norm_backward.cpp index 6e0120b01e9..d806f8cd7b5 100644 --- a/torch_xla/csrc/ops/native_batch_norm_backward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_backward.cpp @@ -39,7 +39,7 @@ NativeBatchNormBackward::NativeBatchNormBackward( return NodeOutputShape(grad_out, input, weight, save_mean, save_invstd, training); }, - /*num_outputs=*/3, xla::util::MHash(training, eps)), + /*num_outputs=*/3, torch::lazy::MHash(training, eps)), training_(training), eps_(eps) {} diff --git a/torch_xla/csrc/ops/native_batch_norm_forward.cpp b/torch_xla/csrc/ops/native_batch_norm_forward.cpp index 4bb16cff435..a7417589835 100644 --- a/torch_xla/csrc/ops/native_batch_norm_forward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_forward.cpp @@ -63,7 +63,7 @@ NativeBatchNormForward::NativeBatchNormForward(const Value& input, return NodeOutputShape(input, weight, bias, running_mean, running_var, training); }, - /*num_outputs=*/4, xla::util::MHash(training, eps)), + /*num_outputs=*/4, torch::lazy::MHash(training, eps)), training_(training), eps_(eps) {} diff --git a/torch_xla/csrc/ops/nll_loss.cpp b/torch_xla/csrc/ops/nll_loss.cpp index 49fe037724c..b25dc3f00ad 100644 --- a/torch_xla/csrc/ops/nll_loss.cpp +++ b/torch_xla/csrc/ops/nll_loss.cpp @@ -43,7 +43,7 @@ NllLoss::NllLoss(const Value& logits, const Value& labels, ignore_index); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/nll_loss2d.cpp b/torch_xla/csrc/ops/nll_loss2d.cpp index 199d3d43c4c..aec10e4949f 100644 --- a/torch_xla/csrc/ops/nll_loss2d.cpp +++ b/torch_xla/csrc/ops/nll_loss2d.cpp @@ -43,7 +43,7 @@ NllLoss2d::NllLoss2d(const Value& logits, const Value& labels, ignore_index); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/nll_loss2d_backward.cpp b/torch_xla/csrc/ops/nll_loss2d_backward.cpp index add32d50b61..110372fa9e1 100644 --- a/torch_xla/csrc/ops/nll_loss2d_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss2d_backward.cpp @@ -52,7 +52,7 @@ NllLoss2dBackward::NllLoss2dBackward(const Value& grad_output, total_weight, reduction, ignore_index); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/nll_loss_backward.cpp b/torch_xla/csrc/ops/nll_loss_backward.cpp index f2a51694ddc..b8f104e9d6f 100644 --- a/torch_xla/csrc/ops/nll_loss_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss_backward.cpp @@ -52,7 +52,7 @@ NllLossBackward::NllLossBackward(const Value& grad_output, const Value& logits, total_weight, reduction, ignore_index); }, /*num_outputs=*/1, - xla::util::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/nms.cpp b/torch_xla/csrc/ops/nms.cpp index dd8798c11d7..aed5f4f9179 100644 --- a/torch_xla/csrc/ops/nms.cpp +++ b/torch_xla/csrc/ops/nms.cpp @@ -35,7 +35,7 @@ Nms::Nms(const Value& boxes, const Value& scores, const Value& score_threshold, return NodeOutputShape(boxes, scores, score_threshold, iou_threshold, output_size); }, - /*num_outputs=*/2, xla::util::MHash(output_size)), + /*num_outputs=*/2, torch::lazy::MHash(output_size)), output_size_(output_size) {} NodePtr Nms::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/not_supported.cpp b/torch_xla/csrc/ops/not_supported.cpp index 5ea05a5b41e..02dfb41b6fc 100644 --- a/torch_xla/csrc/ops/not_supported.cpp +++ b/torch_xla/csrc/ops/not_supported.cpp @@ -11,7 +11,7 @@ namespace ops { NotSupported::NotSupported(std::string description, xla::Shape shape) : Node(xla_not_supported, std::move(shape), /*num_outputs=*/1, - xla::util::MHash(description)), + torch::lazy::MHash(description)), description_(std::move(description)) {} NodePtr NotSupported::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/ops.cpp b/torch_xla/csrc/ops/ops.cpp index 12c4338504a..4cac1c49a50 100644 --- a/torch_xla/csrc/ops/ops.cpp +++ b/torch_xla/csrc/ops/ops.cpp @@ -112,7 +112,7 @@ NodePtr LogBase(const Value& input, OpKind op, double base) { return node.ReturnOp(result * ln_base, loctx); }; return GenericOp(op, {input}, input.shape(), std::move(lower_fn), - /*num_outputs=*/1, xla::util::MHash(base)); + /*num_outputs=*/1, torch::lazy::MHash(base)); } NodePtr ReciprocalOp(const Value& input) { @@ -589,7 +589,7 @@ NodePtr Identity(xla::int64 lines, xla::int64 cols, return GenericOp(OpKind(at::aten::eye), xla::ShapeUtil::MakeShape(element_type, {lines, cols}), std::move(lower_fn), /*num_outputs=*/1, - xla::util::MHash(lines, cols)); + torch::lazy::MHash(lines, cols)); } NodePtr Elu(const Value& input, const at::Scalar& alpha, diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index fa5e428ef24..4772ecf9e27 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -28,7 +28,7 @@ inline NodePtr ConstantOp(xla::Literal value) { inline NodePtr GenericOp(OpKind op, absl::Span operands, xla::Shape shape, Generic::LowerFn lower_fn, size_t num_outputs = 1, - xla::hash_t hash_seed = 0x5a2d296e9) { + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { return MakeNode(std::move(op), operands, std::move(shape), std::move(lower_fn), num_outputs, hash_seed); } @@ -36,13 +36,13 @@ inline NodePtr GenericOp(OpKind op, absl::Span operands, inline NodePtr GenericOp(OpKind op, absl::Span operands, const std::function& shape_fn, Generic::LowerFn lower_fn, size_t num_outputs = 1, - xla::hash_t hash_seed = 0x5a2d296e9) { + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { return MakeNode(std::move(op), operands, shape_fn, std::move(lower_fn), num_outputs, hash_seed); } inline NodePtr GenericOp(OpKind op, xla::Shape shape, Generic::LowerFn lower_fn, - size_t num_outputs, xla::hash_t hash_seed) { + size_t num_outputs, torch::lazy::hash_t hash_seed) { return MakeNode(std::move(op), std::move(shape), std::move(lower_fn), num_outputs, hash_seed); } diff --git a/torch_xla/csrc/ops/permute.cpp b/torch_xla/csrc/ops/permute.cpp index 692b7fef520..8195dbd69c7 100644 --- a/torch_xla/csrc/ops/permute.cpp +++ b/torch_xla/csrc/ops/permute.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, Permute::Permute(const Value& input, std::vector dims) : Node(ir::OpKind(at::aten::permute), {input}, [&]() { return NodeOutputShape(input, dims); }, - /*num_outputs=*/1, xla::util::MHash(dims)), + /*num_outputs=*/1, torch::lazy::MHash(dims)), dims_(std::move(dims)) {} NodePtr Permute::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index 1ea4b9bcc1c..380357fd866 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -51,7 +51,7 @@ Prod::Prod(const Value& input, std::vector dimensions, dtype); }, /*num_outputs=*/1, - xla::util::MHash(dimensions, keep_reduced_dimensions, + torch::lazy::MHash(dimensions, keep_reduced_dimensions, OptionalOr(dtype, -1))), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions), diff --git a/torch_xla/csrc/ops/put.cpp b/torch_xla/csrc/ops/put.cpp index f81a9aa018c..31e6747cdc8 100644 --- a/torch_xla/csrc/ops/put.cpp +++ b/torch_xla/csrc/ops/put.cpp @@ -11,7 +11,7 @@ namespace ops { Put::Put(const Value& input, const Value& index, const Value& source, bool accumulate) : Node(ir::OpKind(at::aten::put), {input, index, source}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(accumulate)), + /*num_outputs=*/1, torch::lazy::MHash(accumulate)), accumulate_(accumulate) {} NodePtr Put::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/qr.cpp b/torch_xla/csrc/ops/qr.cpp index 13ed4d82dc7..1a887b96f73 100644 --- a/torch_xla/csrc/ops/qr.cpp +++ b/torch_xla/csrc/ops/qr.cpp @@ -45,7 +45,7 @@ xla::Shape NodeOutputShape(const Value& input, bool some) { QR::QR(const Value& input, bool some) : Node(ir::OpKind(at::aten::qr), {input}, [&]() { return NodeOutputShape(input, some); }, - /*num_outputs=*/2, xla::util::MHash(some)), + /*num_outputs=*/2, torch::lazy::MHash(some)), some_(some) {} NodePtr QR::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/reflection_pad2d.cpp b/torch_xla/csrc/ops/reflection_pad2d.cpp index 5788b43b317..134374e2b2c 100644 --- a/torch_xla/csrc/ops/reflection_pad2d.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d.cpp @@ -26,7 +26,7 @@ ReflectionPad2d::ReflectionPad2d(const Value& input, std::vector padding) : Node(OpKind(at::aten::reflection_pad2d), {input}, [&]() { return NodeOutputShape(input, padding); }, - /*num_outputs=*/1, xla::util::MHash(padding)), + /*num_outputs=*/1, torch::lazy::MHash(padding)), padding_(std::move(padding)) {} NodePtr ReflectionPad2d::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp index f9c04f46d0c..e5b36b1a4bc 100644 --- a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp @@ -28,7 +28,7 @@ ReflectionPad2dBackward::ReflectionPad2dBackward( std::vector padding) : Node(OpKind(at::aten::reflection_pad2d_backward), {grad_output, input}, [&]() { return NodeOutputShape(grad_output, input, padding); }, - /*num_outputs=*/1, xla::util::MHash(padding)), + /*num_outputs=*/1, torch::lazy::MHash(padding)), padding_(std::move(padding)) {} NodePtr ReflectionPad2dBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/repeat.cpp b/torch_xla/csrc/ops/repeat.cpp index b6af0896fb7..f1600e577b1 100644 --- a/torch_xla/csrc/ops/repeat.cpp +++ b/torch_xla/csrc/ops/repeat.cpp @@ -26,7 +26,7 @@ xla::Shape NodeOutputShape(const Value& input, Repeat::Repeat(const Value& input, std::vector repeats) : Node(ir::OpKind(at::aten::repeat), {input}, [&]() { return NodeOutputShape(input, repeats); }, - /*num_outputs=*/1, xla::util::MHash(repeats)), + /*num_outputs=*/1, torch::lazy::MHash(repeats)), repeats_(std::move(repeats)) {} NodePtr Repeat::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/replication_pad.cpp b/torch_xla/csrc/ops/replication_pad.cpp index a82e34c82a2..44872689321 100644 --- a/torch_xla/csrc/ops/replication_pad.cpp +++ b/torch_xla/csrc/ops/replication_pad.cpp @@ -26,7 +26,7 @@ ReplicationPad::ReplicationPad(const Value& input, std::vector padding) : Node(xla_replication_pad, {input}, [&]() { return NodeOutputShape(input, padding); }, - /*num_outputs=*/1, xla::util::MHash(padding)), + /*num_outputs=*/1, torch::lazy::MHash(padding)), padding_(std::move(padding)) {} NodePtr ReplicationPad::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/replication_pad_backward.cpp b/torch_xla/csrc/ops/replication_pad_backward.cpp index 50eff53d591..c86031f24ac 100644 --- a/torch_xla/csrc/ops/replication_pad_backward.cpp +++ b/torch_xla/csrc/ops/replication_pad_backward.cpp @@ -29,7 +29,7 @@ ReplicationPadBackward::ReplicationPadBackward(const Value& grad_output, std::vector padding) : Node(xla_replication_pad_backward, {grad_output, input}, [&]() { return NodeOutputShape(grad_output, input, padding); }, - /*num_outputs=*/1, xla::util::MHash(padding)), + /*num_outputs=*/1, torch::lazy::MHash(padding)), padding_(std::move(padding)) {} NodePtr ReplicationPadBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/resize.cpp b/torch_xla/csrc/ops/resize.cpp index 9a4c6512acc..e708fdbb150 100644 --- a/torch_xla/csrc/ops/resize.cpp +++ b/torch_xla/csrc/ops/resize.cpp @@ -21,7 +21,7 @@ xla::Shape NodeOutputShape(const Value& input, Resize::Resize(const Value& input, std::vector size) : Node(ir::OpKind(at::aten::resize), {input}, [&]() { return NodeOutputShape(input, size); }, - /*num_outputs=*/1, xla::util::MHash(size)), + /*num_outputs=*/1, torch::lazy::MHash(size)), size_(std::move(size)) {} NodePtr Resize::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/rrelu_with_noise.cpp b/torch_xla/csrc/ops/rrelu_with_noise.cpp index 66b7bb673c9..18d327de96b 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise.cpp @@ -16,7 +16,7 @@ RreluWithNoise::RreluWithNoise(const Value& input, const Value& seed, : Node(ir::OpKind(at::aten::rrelu_with_noise), {input, seed}, xla::ShapeUtil::MakeTupleShape({input.shape(), input.shape()}), /*num_outputs=*/2, - xla::util::MHash(ScalarHash(lower), ScalarHash(upper), training)), + torch::lazy::MHash(ScalarHash(lower), ScalarHash(upper), training)), lower_(std::move(lower)), upper_(std::move(upper)), training_(training) {} diff --git a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp index d941a9bd9f3..b66ad7bce0a 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp @@ -15,7 +15,7 @@ RreluWithNoiseBackward::RreluWithNoiseBackward( : Node(ir::OpKind(at::aten::rrelu_with_noise_backward), {grad_output, input, noise}, input.shape(), /*num_outputs=*/1, - xla::util::MHash(ScalarHash(lower), ScalarHash(upper), training)), + torch::lazy::MHash(ScalarHash(lower), ScalarHash(upper), training)), lower_(std::move(lower)), upper_(std::move(upper)), training_(training) {} diff --git a/torch_xla/csrc/ops/scalar.cpp b/torch_xla/csrc/ops/scalar.cpp index a0879eccb0d..bad4945502f 100644 --- a/torch_xla/csrc/ops/scalar.cpp +++ b/torch_xla/csrc/ops/scalar.cpp @@ -95,9 +95,9 @@ XlaOpVector Scalar::Lower(LoweringContext* loctx) const { return ReturnOp(op, loctx); } -xla::hash_t ScalarHash(const at::Scalar& s) { - return s.isFloatingPoint() ? xla::util::Hash(s.toDouble()) - : xla::util::Hash(s.toLong()); +torch::lazy::hash_t ScalarHash(const at::Scalar& s) { + return s.isFloatingPoint() ? torch::lazy::Hash(s.toDouble()) + : torch::lazy::Hash(s.toLong()); } std::ostream& operator<<(std::ostream& ostrm, at::Scalar s) { diff --git a/torch_xla/csrc/ops/scalar.h b/torch_xla/csrc/ops/scalar.h index fd6ec179eae..3a94c29a98b 100644 --- a/torch_xla/csrc/ops/scalar.h +++ b/torch_xla/csrc/ops/scalar.h @@ -32,7 +32,7 @@ class Scalar : public Node { at::Scalar value_; }; -xla::hash_t ScalarHash(const at::Scalar& s); +torch::lazy::hash_t ScalarHash(const at::Scalar& s); std::ostream& operator<<(std::ostream& ostrm, at::Scalar s); diff --git a/torch_xla/csrc/ops/scatter.cpp b/torch_xla/csrc/ops/scatter.cpp index af3b1e20a1e..acaf850bfbb 100644 --- a/torch_xla/csrc/ops/scatter.cpp +++ b/torch_xla/csrc/ops/scatter.cpp @@ -11,7 +11,7 @@ namespace ops { Scatter::Scatter(const Value& input, const Value& index, const Value& src, xla::int64 dim) : Node(ir::OpKind(at::aten::scatter), {input, index, src}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr Scatter::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/scatter_add.cpp b/torch_xla/csrc/ops/scatter_add.cpp index d8b6adde94f..a692c208bd4 100644 --- a/torch_xla/csrc/ops/scatter_add.cpp +++ b/torch_xla/csrc/ops/scatter_add.cpp @@ -14,7 +14,7 @@ ScatterAdd::ScatterAdd(const Value& input, const Value& index, const Value& src, xla::int64 dim) : Node(ir::OpKind(at::aten::scatter_add), {input, index, src}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr ScatterAdd::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/select.cpp b/torch_xla/csrc/ops/select.cpp index 58aa8bb0094..15236c210cb 100644 --- a/torch_xla/csrc/ops/select.cpp +++ b/torch_xla/csrc/ops/select.cpp @@ -15,7 +15,7 @@ Select::Select(const Value& input, xla::int64 dim, xla::int64 start, [&]() { return MakeSelectShape(input.shape(), dim, start, end, stride); }, - /*num_outputs=*/1, xla::util::MHash(dim, start, end, stride)), + /*num_outputs=*/1, torch::lazy::MHash(dim, start, end, stride)), dim_(dim), start_(start), end_(end), diff --git a/torch_xla/csrc/ops/softmax.cpp b/torch_xla/csrc/ops/softmax.cpp index 50d07ea2e46..a2453ce3c2c 100644 --- a/torch_xla/csrc/ops/softmax.cpp +++ b/torch_xla/csrc/ops/softmax.cpp @@ -34,7 +34,7 @@ Softmax::Softmax(const Value& input, xla::int64 dim, : Node(ir::OpKind(at::aten::softmax), {input}, [&]() { return NodeOutputShape(input, dtype); }, /*num_outputs=*/1, - xla::util::MHash(dim, OptionalOr(dtype, -1))), + torch::lazy::MHash(dim, OptionalOr(dtype, -1))), dim_(dim), dtype_(dtype) {} diff --git a/torch_xla/csrc/ops/softmax_backward.cpp b/torch_xla/csrc/ops/softmax_backward.cpp index 71308104ac4..37b8c75ca7e 100644 --- a/torch_xla/csrc/ops/softmax_backward.cpp +++ b/torch_xla/csrc/ops/softmax_backward.cpp @@ -14,7 +14,7 @@ SoftmaxBackward::SoftmaxBackward(const Value& grad_output, const Value& output, xla::int64 dim) : Node(ir::OpKind(at::aten::_softmax_backward_data), {grad_output, output}, grad_output.shape(), - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr SoftmaxBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/split.cpp b/torch_xla/csrc/ops/split.cpp index 3436c02207f..fc822a38353 100644 --- a/torch_xla/csrc/ops/split.cpp +++ b/torch_xla/csrc/ops/split.cpp @@ -30,7 +30,7 @@ Split::Split(const Value& input, std::vector split_sizes, : Node(ir::OpKind(at::aten::split), {input}, [&]() { return NodeOutputShape(input, split_sizes, dim); }, ComputeSplitCount(input.shape().dimensions(dim), split_sizes), - xla::util::MHash(split_sizes, dim)), + torch::lazy::MHash(split_sizes, dim)), split_sizes_(std::move(split_sizes)), dim_(dim) {} diff --git a/torch_xla/csrc/ops/squeeze.cpp b/torch_xla/csrc/ops/squeeze.cpp index f404163dd1f..0a53c1fe0a2 100644 --- a/torch_xla/csrc/ops/squeeze.cpp +++ b/torch_xla/csrc/ops/squeeze.cpp @@ -33,7 +33,7 @@ xla::Shape NodeOutputShape(const Value& input, int dim) { Squeeze::Squeeze(const Value& input, int dim) : Node(ir::OpKind(at::aten::squeeze), {input}, [&]() { return NodeOutputShape(input, dim); }, - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr Squeeze::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/stack.cpp b/torch_xla/csrc/ops/stack.cpp index f0f0ed894e3..b4ceb7078ec 100644 --- a/torch_xla/csrc/ops/stack.cpp +++ b/torch_xla/csrc/ops/stack.cpp @@ -29,7 +29,7 @@ xla::Shape NodeOutputShape(absl::Span values, xla::int64 dim) { Stack::Stack(absl::Span values, xla::int64 dim) : Node(ir::OpKind(at::aten::stack), values, [&]() { return NodeOutputShape(values, dim); }, - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr Stack::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/std.cpp b/torch_xla/csrc/ops/std.cpp index b6ec485a055..67cb1f2b447 100644 --- a/torch_xla/csrc/ops/std.cpp +++ b/torch_xla/csrc/ops/std.cpp @@ -33,7 +33,7 @@ Std::Std(const Value& input, std::vector dimensions, correction); }, /*num_outputs=*/1, - xla::util::MHash(dimensions, keep_reduced_dimensions, correction)), + torch::lazy::MHash(dimensions, keep_reduced_dimensions, correction)), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions), correction_(correction) {} diff --git a/torch_xla/csrc/ops/std_mean.cpp b/torch_xla/csrc/ops/std_mean.cpp index 30297f4718e..7b8774b2f19 100644 --- a/torch_xla/csrc/ops/std_mean.cpp +++ b/torch_xla/csrc/ops/std_mean.cpp @@ -36,7 +36,7 @@ StdMean::StdMean(const Value& input, std::vector dimensions, correction); }, /*num_outputs=*/2, - xla::util::MHash(dimensions, correction, keep_reduced_dimensions)), + torch::lazy::MHash(dimensions, correction, keep_reduced_dimensions)), dimensions_(std::move(dimensions)), correction_(correction), keep_reduced_dimensions_(keep_reduced_dimensions) {} diff --git a/torch_xla/csrc/ops/sum.cpp b/torch_xla/csrc/ops/sum.cpp index ca44adf1bb2..6edd1c5df17 100644 --- a/torch_xla/csrc/ops/sum.cpp +++ b/torch_xla/csrc/ops/sum.cpp @@ -43,7 +43,7 @@ Sum::Sum(const Value& input, std::vector dimensions, dtype); }, /*num_outputs=*/1, - xla::util::MHash(dimensions, keep_reduced_dimensions, + torch::lazy::MHash(dimensions, keep_reduced_dimensions, OptionalOr(dtype, -1))), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions), diff --git a/torch_xla/csrc/ops/svd.cpp b/torch_xla/csrc/ops/svd.cpp index 849d2b98bcf..21c0f57becf 100644 --- a/torch_xla/csrc/ops/svd.cpp +++ b/torch_xla/csrc/ops/svd.cpp @@ -69,7 +69,7 @@ xla::Shape NodeOutputShape(const Value& input, bool some, bool compute_uv) { SVD::SVD(const Value& input, bool some, bool compute_uv) : Node(ir::OpKind(at::aten::svd), {input}, [&]() { return NodeOutputShape(input, some, compute_uv); }, - /*num_outputs=*/3, xla::util::MHash(some, compute_uv)), + /*num_outputs=*/3, torch::lazy::MHash(some, compute_uv)), some_(some), compute_uv_(compute_uv) {} diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp index 6a6721bed72..a6622a4f361 100644 --- a/torch_xla/csrc/ops/symeig.cpp +++ b/torch_xla/csrc/ops/symeig.cpp @@ -49,7 +49,7 @@ xla::Shape NodeOutputShape(const Value& input, bool eigenvectors, bool lower) { SymEig::SymEig(const Value& input, bool eigenvectors, bool lower) : Node(ir::OpKind(at::aten::symeig), {input}, [&]() { return NodeOutputShape(input, eigenvectors, lower); }, - /*num_outputs=*/2, xla::util::MHash(eigenvectors, lower)), + /*num_outputs=*/2, torch::lazy::MHash(eigenvectors, lower)), eigenvectors_(eigenvectors), lower_(lower) {} diff --git a/torch_xla/csrc/ops/threshold.cpp b/torch_xla/csrc/ops/threshold.cpp index de404deed50..86abe6001a5 100644 --- a/torch_xla/csrc/ops/threshold.cpp +++ b/torch_xla/csrc/ops/threshold.cpp @@ -10,7 +10,7 @@ namespace ops { Threshold::Threshold(const Value& input, float threshold, float value) : Node(ir::OpKind(at::aten::threshold), {input}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(threshold, value)), + /*num_outputs=*/1, torch::lazy::MHash(threshold, value)), threshold_(threshold), value_(value) {} diff --git a/torch_xla/csrc/ops/threshold_backward.cpp b/torch_xla/csrc/ops/threshold_backward.cpp index 8fb9a186920..cb205884ac4 100644 --- a/torch_xla/csrc/ops/threshold_backward.cpp +++ b/torch_xla/csrc/ops/threshold_backward.cpp @@ -11,7 +11,7 @@ namespace ops { ThresholdBackward::ThresholdBackward(const Value& grad_output, const Value& input, float threshold) : Node(ir::OpKind(at::aten::threshold_backward), {grad_output, input}, - input.shape(), /*num_outputs=*/1, xla::util::MHash(threshold)), + input.shape(), /*num_outputs=*/1, torch::lazy::MHash(threshold)), threshold_(threshold) {} NodePtr ThresholdBackward::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp index ebd5484587c..de3864fab97 100644 --- a/torch_xla/csrc/ops/topk.cpp +++ b/torch_xla/csrc/ops/topk.cpp @@ -26,7 +26,7 @@ TopK::TopK(const Value& input, xla::int64 k, xla::int64 dim, bool largest, bool sorted) : Node(ir::OpKind(at::aten::topk), {input}, [&]() { return NodeOutputShape(input, k, dim, largest, sorted); }, - /*num_outputs=*/2, xla::util::MHash(k, dim, largest, sorted)), + /*num_outputs=*/2, torch::lazy::MHash(k, dim, largest, sorted)), k_(k), dim_(dim), largest_(largest), diff --git a/torch_xla/csrc/ops/triangular_solve.cpp b/torch_xla/csrc/ops/triangular_solve.cpp index c58da7eba75..af0ee2ba716 100644 --- a/torch_xla/csrc/ops/triangular_solve.cpp +++ b/torch_xla/csrc/ops/triangular_solve.cpp @@ -80,7 +80,7 @@ TriangularSolve::TriangularSolve(const Value& rhs, const Value& lhs, : Node(ir::OpKind(at::aten::triangular_solve), {rhs, lhs}, [&]() { return NodeOutputShape(rhs, lhs); }, /*num_outputs=*/2, - xla::util::MHash(left_side, lower, transpose, unit_diagonal)), + torch::lazy::MHash(left_side, lower, transpose, unit_diagonal)), left_side_(left_side), lower_(lower), transpose_(transpose), diff --git a/torch_xla/csrc/ops/tril.cpp b/torch_xla/csrc/ops/tril.cpp index f848ccb1ea1..36b1940609f 100644 --- a/torch_xla/csrc/ops/tril.cpp +++ b/torch_xla/csrc/ops/tril.cpp @@ -10,7 +10,7 @@ namespace ops { Tril::Tril(const Value& input, xla::int64 diagonal) : Node(ir::OpKind(at::aten::tril), {input}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(diagonal)), + /*num_outputs=*/1, torch::lazy::MHash(diagonal)), diagonal_(diagonal) {} NodePtr Tril::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/triu.cpp b/torch_xla/csrc/ops/triu.cpp index 330cd44d7c5..6a203ae16aa 100644 --- a/torch_xla/csrc/ops/triu.cpp +++ b/torch_xla/csrc/ops/triu.cpp @@ -10,7 +10,7 @@ namespace ops { Triu::Triu(const Value& input, xla::int64 diagonal) : Node(ir::OpKind(at::aten::triu), {input}, input.shape(), - /*num_outputs=*/1, xla::util::MHash(diagonal)), + /*num_outputs=*/1, torch::lazy::MHash(diagonal)), diagonal_(diagonal) {} NodePtr Triu::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/uniform.cpp b/torch_xla/csrc/ops/uniform.cpp index 81353b39c3f..d2c25128bb7 100644 --- a/torch_xla/csrc/ops/uniform.cpp +++ b/torch_xla/csrc/ops/uniform.cpp @@ -5,6 +5,7 @@ #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/random.h" +#include "torch_xla/csrc/torch_util.h" namespace torch_xla { namespace ir { @@ -13,7 +14,7 @@ namespace ops { Uniform::Uniform(const Value& from, const Value& to, const Value& seed, const xla::Shape& rng_shape) : Node(ir::OpKind(at::aten::uniform), {from, to, seed}, rng_shape, - /*num_outputs=*/1, xla::util::ShapeHash(rng_shape)) {} + /*num_outputs=*/1, torch::lazy::Hash(rng_shape)) {} NodePtr Uniform::Clone(OpList operands) const { return MakeNode(operands.at(0), operands.at(1), operands.at(2), diff --git a/torch_xla/csrc/ops/unselect.cpp b/torch_xla/csrc/ops/unselect.cpp index 8814f8be9a8..aa72e971271 100644 --- a/torch_xla/csrc/ops/unselect.cpp +++ b/torch_xla/csrc/ops/unselect.cpp @@ -15,7 +15,7 @@ namespace ops { Unselect::Unselect(const Value& target, const Value& source, xla::int64 dim, xla::int64 start, xla::int64 end, xla::int64 stride) : Node(xla_unselect, {target, source}, target.shape(), - /*num_outputs=*/1, xla::util::MHash(dim, start, end, stride)), + /*num_outputs=*/1, torch::lazy::MHash(dim, start, end, stride)), dim_(dim), start_(start), end_(end), diff --git a/torch_xla/csrc/ops/unsqueeze.cpp b/torch_xla/csrc/ops/unsqueeze.cpp index ccbd210a470..a956f65d5e0 100644 --- a/torch_xla/csrc/ops/unsqueeze.cpp +++ b/torch_xla/csrc/ops/unsqueeze.cpp @@ -20,7 +20,7 @@ xla::Shape NodeOutputShape(const Value& input, int dim) { Unsqueeze::Unsqueeze(const Value& input, int dim) : Node(ir::OpKind(at::aten::unsqueeze), {input}, [&]() { return NodeOutputShape(input, dim); }, - /*num_outputs=*/1, xla::util::MHash(dim)), + /*num_outputs=*/1, torch::lazy::MHash(dim)), dim_(dim) {} NodePtr Unsqueeze::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/update_slice.cpp b/torch_xla/csrc/ops/update_slice.cpp index 4381a627a1f..80799e56e6c 100644 --- a/torch_xla/csrc/ops/update_slice.cpp +++ b/torch_xla/csrc/ops/update_slice.cpp @@ -6,6 +6,7 @@ #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" +#include "torch_xla/csrc/torch_util.h" namespace torch_xla { namespace ir { @@ -27,7 +28,7 @@ UpdateSlice::UpdateSlice(const Value& input, const Value& source, absl::Span base_indices) : Node(xla_update_slice, {input, source}, [&]() { return NodeOutputShape(input, source, base_indices); }, - /*num_outputs=*/1, xla::util::MHash(base_indices)), + /*num_outputs=*/1, torch::lazy::Hash(base_indices)), base_indices_(base_indices.begin(), base_indices.end()) {} NodePtr UpdateSlice::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/upsample_bilinear2d.cpp b/torch_xla/csrc/ops/upsample_bilinear2d.cpp index 59b4552c2ce..671932a5be8 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d.cpp @@ -17,7 +17,7 @@ UpsampleBilinear::UpsampleBilinear(const Value& input, [&]() { return resize::GetForwardOutputShape2d(input.shape(), output_size); }, - /*num_outputs=*/1, xla::util::MHash(output_size, align_corners)), + /*num_outputs=*/1, torch::lazy::MHash(output_size, align_corners)), output_size_(std::move(output_size)), align_corners_(align_corners) {} diff --git a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp index f5aa4dbd822..ecaf20bd475 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp @@ -18,7 +18,7 @@ UpsampleBilinearBackward::UpsampleBilinearBackward( return resize::GetBackwardOutputShape2d(input.shape(), input_size); }, /*num_outputs=*/1, - xla::util::MHash(output_size, input_size, align_corners)), + torch::lazy::MHash(output_size, input_size, align_corners)), output_size_(std::move(output_size)), input_size_(std::move(input_size)), align_corners_(align_corners) {} diff --git a/torch_xla/csrc/ops/upsample_nearest2d.cpp b/torch_xla/csrc/ops/upsample_nearest2d.cpp index f168dd82272..b478afba719 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d.cpp @@ -16,7 +16,7 @@ UpsampleNearest::UpsampleNearest(const Value& input, [&]() { return resize::GetForwardOutputShape2d(input.shape(), output_size); }, - /*num_outputs=*/1, xla::util::MHash(output_size)), + /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} NodePtr UpsampleNearest::Clone(OpList operands) const { diff --git a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp index c280587c8fa..1900ceb419c 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp @@ -17,7 +17,7 @@ UpsampleNearestBackward::UpsampleNearestBackward( [&]() { return resize::GetBackwardOutputShape2d(input.shape(), input_size); }, - /*num_outputs=*/1, xla::util::MHash(output_size, input_size)), + /*num_outputs=*/1, torch::lazy::MHash(output_size, input_size)), output_size_(std::move(output_size)), input_size_(std::move(input_size)) {} diff --git a/torch_xla/csrc/ops/var.cpp b/torch_xla/csrc/ops/var.cpp index ed8f88c50f2..23a13e5b068 100644 --- a/torch_xla/csrc/ops/var.cpp +++ b/torch_xla/csrc/ops/var.cpp @@ -34,7 +34,7 @@ Var::Var(const Value& input, std::vector dimensions, NodeOutputShape(input, dimensions, correction, keep_reduced_dimensions), /*num_outputs=*/1, - xla::util::MHash(dimensions, correction, keep_reduced_dimensions)), + torch::lazy::MHash(dimensions, correction, keep_reduced_dimensions)), dimensions_(std::move(dimensions)), correction_(correction), keep_reduced_dimensions_(keep_reduced_dimensions) {} diff --git a/torch_xla/csrc/ops/var_mean.cpp b/torch_xla/csrc/ops/var_mean.cpp index 5eb7573ba67..c6270c6ee4a 100644 --- a/torch_xla/csrc/ops/var_mean.cpp +++ b/torch_xla/csrc/ops/var_mean.cpp @@ -39,7 +39,7 @@ VarMean::VarMean(const Value& input, std::vector dimensions, keep_reduced_dimensions); }, /*num_outputs=*/2, - xla::util::MHash(dimensions, correction, keep_reduced_dimensions)), + torch::lazy::MHash(dimensions, correction, keep_reduced_dimensions)), dimensions_(std::move(dimensions)), correction_(correction), keep_reduced_dimensions_(keep_reduced_dimensions) {} diff --git a/torch_xla/csrc/ops/view.cpp b/torch_xla/csrc/ops/view.cpp index d592b29c79c..9183c4a613f 100644 --- a/torch_xla/csrc/ops/view.cpp +++ b/torch_xla/csrc/ops/view.cpp @@ -30,7 +30,7 @@ xla::Shape NodeOutputShape(const Value& input, View::View(const Value& input, std::vector output_size) : Node(ir::OpKind(at::aten::view), {input}, NodeOutputShape(input, output_size), - /*num_outputs=*/1, xla::util::MHash(output_size)), + /*num_outputs=*/1, torch::lazy::MHash(output_size)), output_size_(std::move(output_size)) {} XlaOpVector View::Lower(LoweringContext* loctx) const { diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 05b1b152f29..6ec2737707e 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -39,6 +39,7 @@ #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace { @@ -164,7 +165,7 @@ class XlaDataCacheArena { public: struct TensorHasher { size_t operator()(const at::Tensor& tensor) const { - return xla::util::HashReduce(xla::util::HashCombine( + return torch::lazy::HashReduce(torch::lazy::HashCombine( xla::util::GetEnumValue(tensor.scalar_type()), TensorHash(tensor))); }; }; @@ -1131,7 +1132,7 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors( std::unordered_set tensor_ids; // The force_xla_data controls aliasing compilation, so effectively the same // graph with on/off force_xla_data should not match, hash wise. - coll.hash = xla::util::MHash(config.force_xla_data); + coll.hash = torch::lazy::MHash(config.force_xla_data); coll.config = config; coll.device = *unique_device; coll.indices.reserve(tensors.size()); @@ -1150,7 +1151,7 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors( if (ir_value) { if (ShouldSyncIrValue(ir_value)) { // Add only tensors which need to be synced. - coll.hash = xla::util::HashCombine(coll.hash, ir_value.hash()); + coll.hash = torch::lazy::HashCombine(coll.hash, ir_value.hash()); coll.indices.push_back(i); } } else if (config.force_xla_data) { @@ -1166,7 +1167,7 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors( } // Mix the hash with the resource domain hashes as compile handles are only // valid within a domain (usually a single host). - coll.hash = xla::util::MHash( + coll.hash = torch::lazy::MHash( coll.hash, xla::ComputationClient::Get()->GetResourceDomain(coll.device.ToString())); if (!at_tensors.empty()) { @@ -1181,22 +1182,22 @@ XLATensor::SyncTensorCollection XLATensor::CollectSyncTensors( tensors[at_tensor_index[i]].data()->xla_data = std::move(handles[i]); } } - TF_VLOG(4) << "Tensors graph hash " << xla::util::HexHash(coll.hash) + TF_VLOG(4) << "Tensors graph hash " << torch::lazy::HashToString(coll.hash) << " on device " << coll.device; return coll; } XLATensor::ComputationCache::TypePtr XLATensor::LookupCachedCompile( - const std::vector& tensors, const xla::hash_t& hash) { + const std::vector& tensors, const torch::lazy::hash_t& hash) { ComputationCache::TypePtr cached_computation = GetComputationCache()->Get(hash); if (cached_computation == nullptr) { XLA_COUNTER("UncachedCompile", 1); return nullptr; } - TF_VLOG(5) << "Graph hash " << xla::util::HexHash(hash) + TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(hash) << " is computation hash " - << xla::util::HexHash(xla::util::Hash( + << torch::lazy::HashToString(torch::lazy::Hash( cached_computation->computation->computation() .proto() .SerializeAsString())); @@ -1311,12 +1312,12 @@ std::shared_ptr XLATensor::ScheduleSyncTensorsGraph( auto syncfn = [async, hash = coll->hash]() { xla::ComputationClient::ExecuteComputationOptions options; try { - TF_VLOG(3) << "Executing IR graph hash " << xla::util::HexHash(hash) + TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) << " on device " << async->device << " ..."; auto results = xla::ComputationClient::Get()->ExecuteComputation( *async->cached_computation->computation, async->parameters_data, async->device, options); - TF_VLOG(3) << "Executing IR graph hash " << xla::util::HexHash(hash) + TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) << " on device " << async->device << " done!"; for (size_t i = 0; i < results.size(); ++i) { @@ -1448,13 +1449,13 @@ XLATensor::OpByOpAsync XLATensor::SyncTensorsGraphOpByOp( auto syncfn = [async]() -> int { try { TF_VLOG(3) << "Executing (OpByOp) IR graph hash " - << xla::util::HexHash(async->coll.hash) << " on device " + << torch::lazy::HashToString(async->coll.hash) << " on device " << async->coll.device << " ..."; std::vector results = OpByOpExecutor::Get()->Execute( async->roots, async->coll.device.ToString(), async->devices); TF_VLOG(3) << "Executing (OpByOp) IR graph hash " - << xla::util::HexHash(async->coll.hash) << " on device " + << torch::lazy::HashToString(async->coll.hash) << " on device " << async->coll.device << " done!"; for (size_t i = 0; i < results.size(); ++i) { @@ -1519,7 +1520,7 @@ XLATensor::CompilationResult XLATensor::Compile( [&] { return tensorflow::profiler::TraceMeEncode( "XLATensor::Compile", - {{"graph_hash", xla::util::HexHash(coll.hash)}}); + {{"graph_hash", torch::lazy::HashToString(coll.hash)}}); }, tensorflow::profiler::TraceMeLevel::kInfo); static const bool enable_aliasing = @@ -1571,17 +1572,17 @@ XLATensor::CompilationResult XLATensor::Compile( coll.device.ToString(), devices), &shape}); - TF_VLOG(3) << "Compiling IR graph hash " << xla::util::HexHash(coll.hash) + TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) << " on device " << coll.device << " ..."; std::vector> computations = xla::ComputationClient::Get()->Compile(std::move(instances)); - TF_VLOG(3) << "Compiling IR graph hash " << xla::util::HexHash(coll.hash) + TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) << " on device " << coll.device << " done!"; TF_VLOG(5) - << "Graph hash " << xla::util::HexHash(coll.hash) + << "Graph hash " << torch::lazy::HashToString(coll.hash) << " is computation hash " - << xla::util::HexHash(xla::util::Hash( + << torch::lazy::HashToString(torch::lazy::Hash( computations.front()->computation().proto().SerializeAsString())); XLA_CHECK_EQ(program_shape.parameters_size(), po_data->parameters_data.size()); @@ -1605,10 +1606,10 @@ std::shared_ptr XLATensor::SyncTensorsGraphInternal( &coll.indices); PostOrderData po_data = RunPostOrder(*tensors, coll.indices); - coll.hash = xla::util::HashCombine( - coll.hash, xla::util::Hash(po_data.parameter_sequence)); + coll.hash = torch::lazy::HashCombine( + coll.hash, torch::lazy::Hash(po_data.parameter_sequence)); TF_VLOG(4) << "Parameter sequence graph hash " - << xla::util::HexHash(coll.hash); + << torch::lazy::HashToString(coll.hash); std::shared_ptr async = TryRunCachedSync(tensors, &coll, &po_data); if (async != nullptr) { return async; diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index 11bf844e259..cdc08d4daa2 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -1162,7 +1162,7 @@ class XLATensor { SyncTensorsConfig config; std::vector indices; - xla::hash_t hash; + torch::lazy::hash_t hash; std::vector unlocker; Device device; }; @@ -1190,7 +1190,7 @@ class XLATensor { }; using ComputationCache = - xla::util::Cache; + xla::util::Cache; struct Async { Async(SyncTensorCollection* coll, @@ -1371,7 +1371,7 @@ class XLATensor { absl::Span indices); static ComputationCache::TypePtr LookupCachedCompile( - const std::vector& tensors, const xla::hash_t& hash); + const std::vector& tensors, const torch::lazy::hash_t& hash); static std::shared_ptr TryRunCachedSync( std::vector* tensors, SyncTensorCollection* coll, diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 7f22091b44b..78ae60ac117 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -18,6 +18,7 @@ #include "tensorflow/core/lib/bfloat16/bfloat16.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace { @@ -806,34 +807,34 @@ std::vector XlaDataToTensors( return tensors; } -xla::hash_t TensorHash(const at::Tensor& tensor) { +torch::lazy::hash_t TensorHash(const at::Tensor& tensor) { at::Tensor ctensor = tensor.contiguous(); int64_t size = ctensor.numel() * ctensor.element_size(); switch (ctensor.scalar_type()) { case at::ScalarType::Bool: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Byte: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Char: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Short: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Int: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Long: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Float: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Double: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::BFloat16: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::Half: - return xla::util::DataHash(ctensor.data_ptr(), size); + return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::ComplexFloat: - return xla::util::DataHash(ctensor.data_ptr>(), size); + return torch::lazy::DataHash(ctensor.data_ptr>(), size); case at::ScalarType::ComplexDouble: - return xla::util::DataHash(ctensor.data_ptr>(), + return torch::lazy::DataHash(ctensor.data_ptr>(), size); default: XLA_ERROR() << "Unsupported scalar type: " << ctensor.scalar_type(); diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index bf453a62149..5ee94f35247 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -10,6 +10,7 @@ #include "tensorflow/compiler/xla/xla_client/computation_client.h" #include "torch/csrc/autograd/variable.h" #include "torch_xla/csrc/device.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { @@ -32,7 +33,7 @@ bool TensorCompare(const at::Tensor& t1, const at::Tensor& t2); xla::ComputationClient::DataPtr TensorToXlaData(const at::Tensor& tensor, const Device& device); -xla::hash_t TensorHash(const at::Tensor& tensor); +torch::lazy::hash_t TensorHash(const at::Tensor& tensor); // Retrieves the device data handles by parallel uploading data onto the // corresponding devices. diff --git a/torch_xla/csrc/torch_util.cpp b/torch_xla/csrc/torch_util.cpp index 2d7f4b1138e..8d0797fb007 100644 --- a/torch_xla/csrc/torch_util.cpp +++ b/torch_xla/csrc/torch_util.cpp @@ -1,6 +1,7 @@ #include "torch_xla/csrc/torch_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" +#include "tensorflow/compiler/xla/xla_client/xla_util.h" namespace torch_xla { @@ -33,3 +34,13 @@ at::Tensor UnwrapNumber(const at::Tensor& tensor, at::ScalarType dtype) { } } // namespace torch_xla + +namespace torch { +namespace lazy { +torch::lazy::hash_t Hash(const xla::Shape& shape) { + auto shape_hash = xla::util::ShapeHash(shape); + return c10::uint128(absl::Uint128High64(shape_hash), + absl::Uint128Low64(shape_hash)); +} +} // namespace lazy +} // namespace torch diff --git a/torch_xla/csrc/torch_util.h b/torch_xla/csrc/torch_util.h index 07ed342cb1d..cd7b2f13a54 100644 --- a/torch_xla/csrc/torch_util.h +++ b/torch_xla/csrc/torch_util.h @@ -4,6 +4,8 @@ #include #include +#include "tensorflow/compiler/xla/shape.h" +#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { // Makes a deep copy of an ATEN tensor. @@ -40,3 +42,16 @@ inline bool IsDefined(const c10::optional& tensor) { } } // namespace torch_xla + +namespace torch { +namespace lazy { +// Adapters that provide torch::lazy Hash functions for xla types +torch::lazy::hash_t Hash(const xla::Shape& shape); + +template +torch::lazy::hash_t Hash(absl::Span values) { + return torch::lazy::ContainerHash(values); +} + +} // namespace lazy +} // namespace torch \ No newline at end of file From 0c8bd01ae804e668cc489e648a97472ab68b8db4 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 4 Oct 2021 18:08:58 -0700 Subject: [PATCH 2/9] Torch pin --- torch_patches/.torch_pin | 1 + 1 file changed, 1 insertion(+) create mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin new file mode 100644 index 00000000000..20691d0c6a7 --- /dev/null +++ b/torch_patches/.torch_pin @@ -0,0 +1 @@ +#65635 From 66d6914909e7b7dea402af19ec0bd82173221880 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 5 Oct 2021 11:41:09 -0700 Subject: [PATCH 3/9] clang format --- torch_xla/csrc/debug_util.cpp | 2 +- torch_xla/csrc/ir.cpp | 4 ++-- torch_xla/csrc/ir.h | 10 +++++---- torch_xla/csrc/op_by_op_executor.cpp | 2 +- torch_xla/csrc/ops/all_reduce.cpp | 2 +- torch_xla/csrc/ops/all_to_all.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/cast.cpp | 2 +- .../ops/convolution_backward_overrideable.cpp | 2 +- .../csrc/ops/convolution_overrideable.cpp | 4 ++-- torch_xla/csrc/ops/generic.cpp | 3 ++- torch_xla/csrc/ops/generic.h | 3 ++- torch_xla/csrc/ops/generic_slice.cpp | 4 +++- torch_xla/csrc/ops/max_pool_nd.cpp | 2 +- torch_xla/csrc/ops/max_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/mean.cpp | 2 +- torch_xla/csrc/ops/nll_loss.cpp | 17 +++++++------- torch_xla/csrc/ops/nll_loss2d.cpp | 17 +++++++------- torch_xla/csrc/ops/nll_loss2d_backward.cpp | 19 ++++++++-------- torch_xla/csrc/ops/nll_loss_backward.cpp | 19 ++++++++-------- torch_xla/csrc/ops/ops.h | 17 +++++++------- torch_xla/csrc/ops/prod.cpp | 2 +- torch_xla/csrc/ops/sum.cpp | 2 +- torch_xla/csrc/tensor.cpp | 22 +++++++++++-------- torch_xla/csrc/tensor.h | 3 ++- torch_xla/csrc/tensor_util.cpp | 7 +++--- torch_xla/csrc/tensor_util.h | 2 +- 28 files changed, 97 insertions(+), 80 deletions(-) diff --git a/torch_xla/csrc/debug_util.cpp b/torch_xla/csrc/debug_util.cpp index 8820499a3d3..1964680e872 100644 --- a/torch_xla/csrc/debug_util.cpp +++ b/torch_xla/csrc/debug_util.cpp @@ -10,12 +10,12 @@ #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/compiler/xla/xla_client/unique.h" +#include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/device.h" #include "torch_xla/csrc/ir.h" #include "torch_xla/csrc/ir_dump_util.h" #include "torch_xla/csrc/ir_util.h" #include "torch_xla/csrc/python_util.h" -#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace { diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 40d22936111..586e332e4d0 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -8,8 +8,8 @@ #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" #include "tensorflow/compiler/xla/xla_client/util.h" -#include "torch_xla/csrc/lowering_context.h" #include "torch/csrc/lazy/core/hash.h" +#include "torch_xla/csrc/lowering_context.h" namespace torch_xla { namespace ir { @@ -249,7 +249,7 @@ XlaOpVector Node::Lower(LoweringContext* loctx) const { } torch::lazy::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape, - torch::lazy::hash_t hash_seed) { + torch::lazy::hash_t hash_seed) { torch::lazy::hash_t h = torch::lazy::HashCombine(op.hash(), torch::lazy::Hash(shape.ToString())); return torch::lazy::HashCombine(h, hash_seed); diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h index e0e26cdc461..eea0e0c4bec 100644 --- a/torch_xla/csrc/ir.h +++ b/torch_xla/csrc/ir.h @@ -16,8 +16,8 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/xla_client/types.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "torch_xla/csrc/python_util.h" #include "torch/csrc/lazy/core/hash.h" +#include "torch_xla/csrc/python_util.h" namespace torch_xla { namespace ir { @@ -180,10 +180,12 @@ class Node { // Same as the constructor above, but the shape is generated by a function, // only if needed (shape cache miss). Node(OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); + size_t num_outputs = 1, + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); // Contructor used to create leaf nodes. - Node(OpKind op, xla::Shape shape, size_t num_outputs, torch::lazy::hash_t hash_seed); + Node(OpKind op, xla::Shape shape, size_t num_outputs, + torch::lazy::hash_t hash_seed); virtual ~Node(); @@ -245,7 +247,7 @@ class Node { xla::Shape GetOpShape(const std::function& shape_fn) const; static torch::lazy::hash_t GetOpHash(OpKind op, const xla::Shape& shape, - torch::lazy::hash_t hash_seed); + torch::lazy::hash_t hash_seed); static std::vector GetFrameInfo(); diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp index 36621a7116f..a2c0ffcafff 100644 --- a/torch_xla/csrc/op_by_op_executor.cpp +++ b/torch_xla/csrc/op_by_op_executor.cpp @@ -50,7 +50,7 @@ torch::lazy::hash_t ComputeNodeKey( const auto& operands = node->operands(); for (size_t i = 0; i < operands.size(); ++i) { key = torch::lazy::HashCombine(key, torch::lazy::Hash(GetParameterShape( - operands[i], *input_shapes[i]))); + operands[i], *input_shapes[i]))); } key = torch::lazy::HashCombine(key, torch::lazy::Hash(node->shape())); return torch::lazy::HashCombine(key, node->node_hash()); diff --git a/torch_xla/csrc/ops/all_reduce.cpp b/torch_xla/csrc/ops/all_reduce.cpp index c83c63b2c0d..727d0160005 100644 --- a/torch_xla/csrc/ops/all_reduce.cpp +++ b/torch_xla/csrc/ops/all_reduce.cpp @@ -38,7 +38,7 @@ AllReduce::AllReduce(AllReduceType reduce_type, [&]() { return NodeOutputShape(operands, token); }, /*num_outputs=*/operands.size() + 1, torch::lazy::MHash(xla::util::GetEnumValue(reduce_type), scale, - groups)), + groups)), reduce_type_(reduce_type), scale_(scale), groups_(std::move(groups)) {} diff --git a/torch_xla/csrc/ops/all_to_all.cpp b/torch_xla/csrc/ops/all_to_all.cpp index 2c4f92d4a8d..c7a30146040 100644 --- a/torch_xla/csrc/ops/all_to_all.cpp +++ b/torch_xla/csrc/ops/all_to_all.cpp @@ -38,7 +38,7 @@ AllToAll::AllToAll(const Value& input, const Value& token, }, /*num_outputs=*/2, torch::lazy::MHash(split_dimension, concat_dimension, split_count, - groups)), + groups)), split_dimension_(split_dimension), concat_dimension_(concat_dimension), split_count_(split_count), diff --git a/torch_xla/csrc/ops/avg_pool_nd.cpp b/torch_xla/csrc/ops/avg_pool_nd.cpp index 06fca993b8e..d52515aadd8 100644 --- a/torch_xla/csrc/ops/avg_pool_nd.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd.cpp @@ -57,7 +57,7 @@ AvgPoolNd::AvgPoolNd(const Value& input, xla::int64 spatial_dim_count, }, /*num_outputs=*/1, torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, - ceil_mode, count_include_pad)), + ceil_mode, count_include_pad)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), stride_(std::move(stride)), diff --git a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp index dc3c65b4485..62710693fde 100644 --- a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp @@ -56,7 +56,7 @@ AvgPoolNdBackward::AvgPoolNdBackward( }, /*num_outputs=*/1, torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, - ceil_mode, count_include_pad)), + ceil_mode, count_include_pad)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), stride_(std::move(stride)), diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp index 471a7f21c05..e5aa55b8f76 100644 --- a/torch_xla/csrc/ops/cast.cpp +++ b/torch_xla/csrc/ops/cast.cpp @@ -36,7 +36,7 @@ Cast::Cast(const Value& input, at::ScalarType dtype, MakeXlaPrimitiveType(dtype, /*device=*/nullptr)), /*num_outputs=*/1, torch::lazy::MHash(101, static_cast(dtype), - OptionalOr(stype, -1))), + OptionalOr(stype, -1))), type_(MakeXlaPrimitiveType(dtype, /*device=*/nullptr)), dtype_(dtype), stype_(stype) {} diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp index 118aee1224c..462dccfc8df 100644 --- a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp @@ -48,7 +48,7 @@ ConvolutionBackwardOverrideable::ConvolutionBackwardOverrideable( }, /*num_outputs=*/3, torch::lazy::MHash(stride, padding, dilation, transposed, - output_padding, groups)), + output_padding, groups)), stride_(std::move(stride)), padding_(std::move(padding)), dilation_(std::move(dilation)), diff --git a/torch_xla/csrc/ops/convolution_overrideable.cpp b/torch_xla/csrc/ops/convolution_overrideable.cpp index 6581c3b5fe6..4da6a512528 100644 --- a/torch_xla/csrc/ops/convolution_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_overrideable.cpp @@ -46,7 +46,7 @@ ConvolutionOverrideable::ConvolutionOverrideable( }, /*num_outputs=*/1, torch::lazy::MHash(stride, padding, dilation, transposed, - output_padding, groups)), + output_padding, groups)), stride_(std::move(stride)), padding_(std::move(padding)), dilation_(std::move(dilation)), @@ -65,7 +65,7 @@ ConvolutionOverrideable::ConvolutionOverrideable( }, /*num_outputs=*/1, torch::lazy::MHash(stride, padding, dilation, transposed, - output_padding, groups)), + output_padding, groups)), stride_(std::move(stride)), padding_(std::move(padding)), dilation_(std::move(dilation)), diff --git a/torch_xla/csrc/ops/generic.cpp b/torch_xla/csrc/ops/generic.cpp index 85fa5532576..0a774fa15a6 100644 --- a/torch_xla/csrc/ops/generic.cpp +++ b/torch_xla/csrc/ops/generic.cpp @@ -7,7 +7,8 @@ namespace ir { namespace ops { Generic::Generic(OpKind op, absl::Span operands, xla::Shape shape, - LowerFn lower_fn, size_t num_outputs, torch::lazy::hash_t hash_seed) + LowerFn lower_fn, size_t num_outputs, + torch::lazy::hash_t hash_seed) : Node(std::move(op), operands, std::move(shape), num_outputs, hash_seed), lower_fn_(std::move(lower_fn)), hash_seed_(hash_seed) {} diff --git a/torch_xla/csrc/ops/generic.h b/torch_xla/csrc/ops/generic.h index 394fd86aebc..f426426e5b1 100644 --- a/torch_xla/csrc/ops/generic.h +++ b/torch_xla/csrc/ops/generic.h @@ -21,7 +21,8 @@ class Generic : public Node { Generic(OpKind op, absl::Span operands, const std::function& shape_fn, LowerFn lower_fn, - size_t num_outputs = 1, torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); + size_t num_outputs = 1, + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9); Generic(OpKind op, xla::Shape shape, LowerFn lower_fn, size_t num_outputs, torch::lazy::hash_t hash_seed); diff --git a/torch_xla/csrc/ops/generic_slice.cpp b/torch_xla/csrc/ops/generic_slice.cpp index b27e39aedd1..6cddcc91475 100644 --- a/torch_xla/csrc/ops/generic_slice.cpp +++ b/torch_xla/csrc/ops/generic_slice.cpp @@ -30,7 +30,9 @@ GenericSlice::GenericSlice(const Value& input, absl::Span sizes) : Node(xla_generic_slice, {input}, [&]() { return NodeOutputShape(input, base_indices, sizes); }, - /*num_outputs=*/1, torch::lazy::MHash(torch::lazy::Hash(base_indices), torch::lazy::Hash(sizes))), + /*num_outputs=*/1, + torch::lazy::MHash(torch::lazy::Hash(base_indices), + torch::lazy::Hash(sizes))), base_indices_(base_indices.begin(), base_indices.end()), sizes_(sizes.begin(), sizes.end()) {} diff --git a/torch_xla/csrc/ops/max_pool_nd.cpp b/torch_xla/csrc/ops/max_pool_nd.cpp index 2d1046200f6..7906aef2698 100644 --- a/torch_xla/csrc/ops/max_pool_nd.cpp +++ b/torch_xla/csrc/ops/max_pool_nd.cpp @@ -52,7 +52,7 @@ MaxPoolNd::MaxPoolNd(const Value& input, xla::int64 spatial_dim_count, }, /*num_outputs=*/2, torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, - ceil_mode)), + ceil_mode)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), stride_(std::move(stride)), diff --git a/torch_xla/csrc/ops/max_pool_nd_backward.cpp b/torch_xla/csrc/ops/max_pool_nd_backward.cpp index d8959146eea..663d6869d62 100644 --- a/torch_xla/csrc/ops/max_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_pool_nd_backward.cpp @@ -54,7 +54,7 @@ MaxPoolNdBackward::MaxPoolNdBackward( }, /*num_outputs=*/1, torch::lazy::MHash(spatial_dim_count, kernel_size, stride, padding, - ceil_mode)), + ceil_mode)), spatial_dim_count_(spatial_dim_count), kernel_size_(std::move(kernel_size)), stride_(std::move(stride)), diff --git a/torch_xla/csrc/ops/mean.cpp b/torch_xla/csrc/ops/mean.cpp index b11885aef5a..021d5662a65 100644 --- a/torch_xla/csrc/ops/mean.cpp +++ b/torch_xla/csrc/ops/mean.cpp @@ -46,7 +46,7 @@ Mean::Mean(const Value& input, std::vector dimensions, }, /*num_outputs=*/1, torch::lazy::MHash(dimensions, keep_reduced_dimensions, - OptionalOr(dtype, -1))), + OptionalOr(dtype, -1))), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions), dtype_(dtype) {} diff --git a/torch_xla/csrc/ops/nll_loss.cpp b/torch_xla/csrc/ops/nll_loss.cpp index b25dc3f00ad..e8cf9efb74b 100644 --- a/torch_xla/csrc/ops/nll_loss.cpp +++ b/torch_xla/csrc/ops/nll_loss.cpp @@ -36,14 +36,15 @@ xla::Shape NodeOutputShape(const Value& logits, const Value& labels, NllLoss::NllLoss(const Value& logits, const Value& labels, const absl::optional& weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss), - xla::util::GetValuesVector({logits, labels}, {&weight}), - [&]() { - return NodeOutputShape(logits, labels, weight, reduction, - ignore_index); - }, - /*num_outputs=*/1, - torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + : Node( + ir::OpKind(at::aten::nll_loss), + xla::util::GetValuesVector({logits, labels}, {&weight}), + [&]() { + return NodeOutputShape(logits, labels, weight, reduction, + ignore_index); + }, + /*num_outputs=*/1, + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/nll_loss2d.cpp b/torch_xla/csrc/ops/nll_loss2d.cpp index aec10e4949f..53f54c7f43f 100644 --- a/torch_xla/csrc/ops/nll_loss2d.cpp +++ b/torch_xla/csrc/ops/nll_loss2d.cpp @@ -36,14 +36,15 @@ xla::Shape NodeOutputShape(const Value& logits, const Value& labels, NllLoss2d::NllLoss2d(const Value& logits, const Value& labels, const absl::optional& weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss2d), - xla::util::GetValuesVector({logits, labels}, {&weight}), - [&]() { - return NodeOutputShape(logits, labels, weight, reduction, - ignore_index); - }, - /*num_outputs=*/1, - torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + : Node( + ir::OpKind(at::aten::nll_loss2d), + xla::util::GetValuesVector({logits, labels}, {&weight}), + [&]() { + return NodeOutputShape(logits, labels, weight, reduction, + ignore_index); + }, + /*num_outputs=*/1, + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/nll_loss2d_backward.cpp b/torch_xla/csrc/ops/nll_loss2d_backward.cpp index 110372fa9e1..cce21bb3ed0 100644 --- a/torch_xla/csrc/ops/nll_loss2d_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss2d_backward.cpp @@ -44,15 +44,16 @@ NllLoss2dBackward::NllLoss2dBackward(const Value& grad_output, const absl::optional& weight, const absl::optional& total_weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss2d_backward), - xla::util::GetValuesVector({grad_output, logits, labels}, - {&weight, &total_weight}), - [&]() { - return NodeOutputShape(grad_output, logits, labels, weight, - total_weight, reduction, ignore_index); - }, - /*num_outputs=*/1, - torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + : Node( + ir::OpKind(at::aten::nll_loss2d_backward), + xla::util::GetValuesVector({grad_output, logits, labels}, + {&weight, &total_weight}), + [&]() { + return NodeOutputShape(grad_output, logits, labels, weight, + total_weight, reduction, ignore_index); + }, + /*num_outputs=*/1, + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/nll_loss_backward.cpp b/torch_xla/csrc/ops/nll_loss_backward.cpp index b8f104e9d6f..097fcda4286 100644 --- a/torch_xla/csrc/ops/nll_loss_backward.cpp +++ b/torch_xla/csrc/ops/nll_loss_backward.cpp @@ -44,15 +44,16 @@ NllLossBackward::NllLossBackward(const Value& grad_output, const Value& logits, const absl::optional& weight, const absl::optional& total_weight, ReductionMode reduction, int ignore_index) - : Node(ir::OpKind(at::aten::nll_loss_backward), - xla::util::GetValuesVector({grad_output, logits, labels}, - {&weight, &total_weight}), - [&]() { - return NodeOutputShape(grad_output, logits, labels, weight, - total_weight, reduction, ignore_index); - }, - /*num_outputs=*/1, - torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), + : Node( + ir::OpKind(at::aten::nll_loss_backward), + xla::util::GetValuesVector({grad_output, logits, labels}, + {&weight, &total_weight}), + [&]() { + return NodeOutputShape(grad_output, logits, labels, weight, + total_weight, reduction, ignore_index); + }, + /*num_outputs=*/1, + torch::lazy::MHash(xla::util::GetEnumValue(reduction), ignore_index)), reduction_(reduction), ignore_index_(ignore_index) {} diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 4772ecf9e27..9cbbda022fe 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -25,18 +25,19 @@ inline NodePtr ConstantOp(xla::Literal value) { return MakeNode(std::move(value)); } -inline NodePtr GenericOp(OpKind op, absl::Span operands, - xla::Shape shape, Generic::LowerFn lower_fn, - size_t num_outputs = 1, - torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { +inline NodePtr GenericOp( + OpKind op, absl::Span operands, xla::Shape shape, + Generic::LowerFn lower_fn, size_t num_outputs = 1, + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { return MakeNode(std::move(op), operands, std::move(shape), std::move(lower_fn), num_outputs, hash_seed); } -inline NodePtr GenericOp(OpKind op, absl::Span operands, - const std::function& shape_fn, - Generic::LowerFn lower_fn, size_t num_outputs = 1, - torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { +inline NodePtr GenericOp( + OpKind op, absl::Span operands, + const std::function& shape_fn, Generic::LowerFn lower_fn, + size_t num_outputs = 1, + torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { return MakeNode(std::move(op), operands, shape_fn, std::move(lower_fn), num_outputs, hash_seed); } diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index 380357fd866..8e276b1259c 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -52,7 +52,7 @@ Prod::Prod(const Value& input, std::vector dimensions, }, /*num_outputs=*/1, torch::lazy::MHash(dimensions, keep_reduced_dimensions, - OptionalOr(dtype, -1))), + OptionalOr(dtype, -1))), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions), dtype_(dtype) {} diff --git a/torch_xla/csrc/ops/sum.cpp b/torch_xla/csrc/ops/sum.cpp index 6edd1c5df17..c42178c2cd0 100644 --- a/torch_xla/csrc/ops/sum.cpp +++ b/torch_xla/csrc/ops/sum.cpp @@ -44,7 +44,7 @@ Sum::Sum(const Value& input, std::vector dimensions, }, /*num_outputs=*/1, torch::lazy::MHash(dimensions, keep_reduced_dimensions, - OptionalOr(dtype, -1))), + OptionalOr(dtype, -1))), dimensions_(std::move(dimensions)), keep_reduced_dimensions_(keep_reduced_dimensions), dtype_(dtype) {} diff --git a/torch_xla/csrc/tensor.cpp b/torch_xla/csrc/tensor.cpp index 6ec2737707e..1225cd10469 100644 --- a/torch_xla/csrc/tensor.cpp +++ b/torch_xla/csrc/tensor.cpp @@ -25,6 +25,7 @@ #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/profiler/lib/traceme.h" #include "torch/csrc/autograd/variable.h" +#include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/debug_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/ir_dump_util.h" @@ -39,7 +40,6 @@ #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/torch_util.h" -#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace { @@ -1312,13 +1312,15 @@ std::shared_ptr XLATensor::ScheduleSyncTensorsGraph( auto syncfn = [async, hash = coll->hash]() { xla::ComputationClient::ExecuteComputationOptions options; try { - TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) - << " on device " << async->device << " ..."; + TF_VLOG(3) << "Executing IR graph hash " + << torch::lazy::HashToString(hash) << " on device " + << async->device << " ..."; auto results = xla::ComputationClient::Get()->ExecuteComputation( *async->cached_computation->computation, async->parameters_data, async->device, options); - TF_VLOG(3) << "Executing IR graph hash " << torch::lazy::HashToString(hash) - << " on device " << async->device << " done!"; + TF_VLOG(3) << "Executing IR graph hash " + << torch::lazy::HashToString(hash) << " on device " + << async->device << " done!"; for (size_t i = 0; i < results.size(); ++i) { if (async->tensors_data[i] != nullptr) { @@ -1572,13 +1574,15 @@ XLATensor::CompilationResult XLATensor::Compile( coll.device.ToString(), devices), &shape}); - TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) - << " on device " << coll.device << " ..."; + TF_VLOG(3) << "Compiling IR graph hash " + << torch::lazy::HashToString(coll.hash) << " on device " + << coll.device << " ..."; std::vector> computations = xla::ComputationClient::Get()->Compile(std::move(instances)); - TF_VLOG(3) << "Compiling IR graph hash " << torch::lazy::HashToString(coll.hash) - << " on device " << coll.device << " done!"; + TF_VLOG(3) << "Compiling IR graph hash " + << torch::lazy::HashToString(coll.hash) << " on device " + << coll.device << " done!"; TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(coll.hash) << " is computation hash " diff --git a/torch_xla/csrc/tensor.h b/torch_xla/csrc/tensor.h index cdc08d4daa2..16d17360745 100644 --- a/torch_xla/csrc/tensor.h +++ b/torch_xla/csrc/tensor.h @@ -1190,7 +1190,8 @@ class XLATensor { }; using ComputationCache = - xla::util::Cache; + xla::util::Cache; struct Async { Async(SyncTensorCollection* coll, diff --git a/torch_xla/csrc/tensor_util.cpp b/torch_xla/csrc/tensor_util.cpp index 78ae60ac117..205eb8dd3cc 100644 --- a/torch_xla/csrc/tensor_util.cpp +++ b/torch_xla/csrc/tensor_util.cpp @@ -16,9 +16,9 @@ #include "tensorflow/compiler/xla/xla_client/thread_pool.h" #include "tensorflow/compiler/xla/xla_client/util.h" #include "tensorflow/core/lib/bfloat16/bfloat16.h" +#include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/layout_manager.h" -#include "torch/csrc/lazy/core/hash.h" namespace torch_xla { namespace { @@ -832,10 +832,11 @@ torch::lazy::hash_t TensorHash(const at::Tensor& tensor) { case at::ScalarType::Half: return torch::lazy::DataHash(ctensor.data_ptr(), size); case at::ScalarType::ComplexFloat: - return torch::lazy::DataHash(ctensor.data_ptr>(), size); + return torch::lazy::DataHash(ctensor.data_ptr>(), + size); case at::ScalarType::ComplexDouble: return torch::lazy::DataHash(ctensor.data_ptr>(), - size); + size); default: XLA_ERROR() << "Unsupported scalar type: " << ctensor.scalar_type(); } diff --git a/torch_xla/csrc/tensor_util.h b/torch_xla/csrc/tensor_util.h index 5ee94f35247..943fed8711d 100644 --- a/torch_xla/csrc/tensor_util.h +++ b/torch_xla/csrc/tensor_util.h @@ -9,8 +9,8 @@ #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_client/computation_client.h" #include "torch/csrc/autograd/variable.h" -#include "torch_xla/csrc/device.h" #include "torch/csrc/lazy/core/hash.h" +#include "torch_xla/csrc/device.h" namespace torch_xla { From 1ee8dd490772406ece2c8d094d3480cc8248a95b Mon Sep 17 00:00:00 2001 From: Will Constable Date: Tue, 5 Oct 2021 17:37:20 -0700 Subject: [PATCH 4/9] Update .torch_pin --- torch_patches/.torch_pin | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin index 20691d0c6a7..f08d56e5bed 100644 --- a/torch_patches/.torch_pin +++ b/torch_patches/.torch_pin @@ -1 +1 @@ -#65635 +#66181 From f33930456ca04310c0f69758892dadf1c52114eb Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 6 Oct 2021 10:43:07 -0700 Subject: [PATCH 5/9] Clean up unneeded include util.h - remove from places that only used it for util::Hash --- torch_xla/csrc/computation.cpp | 2 +- torch_xla/csrc/ir.cpp | 2 +- torch_xla/csrc/op_by_op_executor.cpp | 2 +- torch_xla/csrc/ops/adaptive_avg_pool2d.cpp | 2 +- torch_xla/csrc/ops/adaptive_avg_pool3d.cpp | 2 +- torch_xla/csrc/ops/adaptive_max_pool2d.cpp | 2 +- torch_xla/csrc/ops/all.cpp | 2 +- torch_xla/csrc/ops/all_to_all.cpp | 2 +- torch_xla/csrc/ops/amax.cpp | 2 +- torch_xla/csrc/ops/amin.cpp | 2 +- torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp | 2 +- torch_xla/csrc/ops/amp_update_scale.cpp | 2 +- torch_xla/csrc/ops/any.cpp | 2 +- torch_xla/csrc/ops/arg_max.cpp | 2 +- torch_xla/csrc/ops/arg_min.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd.cpp | 2 +- torch_xla/csrc/ops/avg_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/bernoulli.cpp | 2 +- torch_xla/csrc/ops/cast.cpp | 2 +- torch_xla/csrc/ops/cat.cpp | 2 +- torch_xla/csrc/ops/cholesky.cpp | 2 +- torch_xla/csrc/ops/collective_permute.cpp | 2 +- torch_xla/csrc/ops/constant_pad_nd.cpp | 2 +- torch_xla/csrc/ops/convolution_backward_overrideable.cpp | 2 +- torch_xla/csrc/ops/convolution_overrideable.cpp | 2 +- torch_xla/csrc/ops/diagonal.cpp | 2 +- torch_xla/csrc/ops/diagonal_view_update.cpp | 2 +- torch_xla/csrc/ops/discrete_uniform.cpp | 2 +- torch_xla/csrc/ops/expand.cpp | 2 +- torch_xla/csrc/ops/exponential.cpp | 2 +- torch_xla/csrc/ops/flip.cpp | 2 +- torch_xla/csrc/ops/gather.cpp | 2 +- torch_xla/csrc/ops/generic_slice.cpp | 2 +- torch_xla/csrc/ops/get_dimensions_size.cpp | 2 +- torch_xla/csrc/ops/hardtanh_backward.cpp | 2 +- torch_xla/csrc/ops/index_get.cpp | 2 +- torch_xla/csrc/ops/index_put.cpp | 2 +- torch_xla/csrc/ops/index_select.cpp | 2 +- torch_xla/csrc/ops/kth_value.cpp | 2 +- torch_xla/csrc/ops/leaky_relu.cpp | 2 +- torch_xla/csrc/ops/leaky_relu_backward.cpp | 2 +- torch_xla/csrc/ops/linear_interpolation.cpp | 2 +- torch_xla/csrc/ops/log_softmax.cpp | 2 +- torch_xla/csrc/ops/log_softmax_backward.cpp | 2 +- torch_xla/csrc/ops/logsumexp.cpp | 2 +- torch_xla/csrc/ops/masked_fill.cpp | 2 +- torch_xla/csrc/ops/masked_scatter.cpp | 2 +- torch_xla/csrc/ops/masked_select.cpp | 2 +- torch_xla/csrc/ops/max_in_dim.cpp | 2 +- torch_xla/csrc/ops/max_pool_nd.cpp | 2 +- torch_xla/csrc/ops/max_pool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/max_unpool_nd.cpp | 2 +- torch_xla/csrc/ops/max_unpool_nd_backward.cpp | 2 +- torch_xla/csrc/ops/mean.cpp | 2 +- torch_xla/csrc/ops/min_in_dim.cpp | 2 +- torch_xla/csrc/ops/native_batch_norm_backward.cpp | 2 +- torch_xla/csrc/ops/native_batch_norm_forward.cpp | 2 +- torch_xla/csrc/ops/nms.cpp | 2 +- torch_xla/csrc/ops/nonzero.cpp | 2 +- torch_xla/csrc/ops/normal.cpp | 2 +- torch_xla/csrc/ops/not_supported.cpp | 2 +- torch_xla/csrc/ops/permute.cpp | 2 +- torch_xla/csrc/ops/prod.cpp | 2 +- torch_xla/csrc/ops/put.cpp | 2 +- torch_xla/csrc/ops/qr.cpp | 2 +- torch_xla/csrc/ops/reflection_pad2d.cpp | 2 +- torch_xla/csrc/ops/reflection_pad2d_backward.cpp | 2 +- torch_xla/csrc/ops/repeat.cpp | 2 +- torch_xla/csrc/ops/replication_pad.cpp | 2 +- torch_xla/csrc/ops/replication_pad_backward.cpp | 2 +- torch_xla/csrc/ops/resize.cpp | 2 +- torch_xla/csrc/ops/rrelu_with_noise.cpp | 2 +- torch_xla/csrc/ops/rrelu_with_noise_backward.cpp | 2 +- torch_xla/csrc/ops/scalar.cpp | 2 +- torch_xla/csrc/ops/scatter.cpp | 2 +- torch_xla/csrc/ops/scatter_add.cpp | 2 +- torch_xla/csrc/ops/select.cpp | 2 +- torch_xla/csrc/ops/softmax.cpp | 2 +- torch_xla/csrc/ops/softmax_backward.cpp | 2 +- torch_xla/csrc/ops/split.cpp | 2 +- torch_xla/csrc/ops/squeeze.cpp | 2 +- torch_xla/csrc/ops/stack.cpp | 2 +- torch_xla/csrc/ops/std.cpp | 2 +- torch_xla/csrc/ops/std_mean.cpp | 2 +- torch_xla/csrc/ops/sum.cpp | 2 +- torch_xla/csrc/ops/symeig.cpp | 2 +- torch_xla/csrc/ops/threshold.cpp | 2 +- torch_xla/csrc/ops/threshold_backward.cpp | 2 +- torch_xla/csrc/ops/topk.cpp | 2 +- torch_xla/csrc/ops/triangular_solve.cpp | 2 +- torch_xla/csrc/ops/tril.cpp | 2 +- torch_xla/csrc/ops/triu.cpp | 2 +- torch_xla/csrc/ops/uniform.cpp | 2 +- torch_xla/csrc/ops/unselect.cpp | 2 +- torch_xla/csrc/ops/unsqueeze.cpp | 2 +- torch_xla/csrc/ops/update_slice.cpp | 2 +- torch_xla/csrc/ops/upsample_bilinear2d.cpp | 2 +- torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp | 2 +- torch_xla/csrc/ops/upsample_nearest2d.cpp | 2 +- torch_xla/csrc/ops/upsample_nearest2d_backward.cpp | 2 +- torch_xla/csrc/ops/var.cpp | 2 +- torch_xla/csrc/ops/var_mean.cpp | 2 +- torch_xla/csrc/ops/view.cpp | 2 +- torch_xla/csrc/resize_ops.cpp | 2 +- 104 files changed, 104 insertions(+), 104 deletions(-) diff --git a/torch_xla/csrc/computation.cpp b/torch_xla/csrc/computation.cpp index 5a83e06be07..96510e2c0a4 100644 --- a/torch_xla/csrc/computation.cpp +++ b/torch_xla/csrc/computation.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/computation.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + namespace torch_xla { diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index 586e332e4d0..d79d3c388f3 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -7,7 +7,7 @@ #include "tensorflow/compiler/xla/xla_client/cache.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp index a2c0ffcafff..cb5e1e4372e 100644 --- a/torch_xla/csrc/op_by_op_executor.cpp +++ b/torch_xla/csrc/op_by_op_executor.cpp @@ -8,7 +8,7 @@ #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/metrics.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp index 5b2795dbe7d..82babc3d469 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/adaptive_avg_pool2d.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp index 6c311ac1cf2..9dd32ed47fe 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/adaptive_avg_pool3d.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp index e8f76e36aca..6579f52614e 100644 --- a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/adaptive_max_pool2d.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/all.cpp b/torch_xla/csrc/ops/all.cpp index eb009c29a99..b167a38f185 100644 --- a/torch_xla/csrc/ops/all.cpp +++ b/torch_xla/csrc/ops/all.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/all.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/all_to_all.cpp b/torch_xla/csrc/ops/all_to_all.cpp index c7a30146040..064406c0efe 100644 --- a/torch_xla/csrc/ops/all_to_all.cpp +++ b/torch_xla/csrc/ops/all_to_all.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/amax.cpp b/torch_xla/csrc/ops/amax.cpp index cda79db570e..b6a9bdf5d21 100644 --- a/torch_xla/csrc/ops/amax.cpp +++ b/torch_xla/csrc/ops/amax.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/amax.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/amin.cpp b/torch_xla/csrc/ops/amin.cpp index 51350ea54e7..c6fede8714e 100644 --- a/torch_xla/csrc/ops/amin.cpp +++ b/torch_xla/csrc/ops/amin.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/amin.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp index 19da4ea2881..1ab741408b5 100644 --- a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp +++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/tensor_util.h" diff --git a/torch_xla/csrc/ops/amp_update_scale.cpp b/torch_xla/csrc/ops/amp_update_scale.cpp index 67b46f5b7a3..6570c057c74 100644 --- a/torch_xla/csrc/ops/amp_update_scale.cpp +++ b/torch_xla/csrc/ops/amp_update_scale.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/amp_update_scale.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/tensor_util.h" diff --git a/torch_xla/csrc/ops/any.cpp b/torch_xla/csrc/ops/any.cpp index b5db12ade8b..d19dae9286c 100644 --- a/torch_xla/csrc/ops/any.cpp +++ b/torch_xla/csrc/ops/any.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/any.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/arg_max.cpp b/torch_xla/csrc/ops/arg_max.cpp index 56618008047..eb6815e97be 100644 --- a/torch_xla/csrc/ops/arg_max.cpp +++ b/torch_xla/csrc/ops/arg_max.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/arg_max.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/arg_min.cpp b/torch_xla/csrc/ops/arg_min.cpp index 512dc5c2be7..381ce9a5d18 100644 --- a/torch_xla/csrc/ops/arg_min.cpp +++ b/torch_xla/csrc/ops/arg_min.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/arg_min.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/avg_pool_nd.cpp b/torch_xla/csrc/ops/avg_pool_nd.cpp index d52515aadd8..38a5dad7922 100644 --- a/torch_xla/csrc/ops/avg_pool_nd.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp index 62710693fde..c304da1d784 100644 --- a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/avg_pool_nd_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/bernoulli.cpp b/torch_xla/csrc/ops/bernoulli.cpp index 63ba971d6c7..43e59e9bd41 100644 --- a/torch_xla/csrc/ops/bernoulli.cpp +++ b/torch_xla/csrc/ops/bernoulli.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/bernoulli.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp index e5aa55b8f76..a3b814f20b1 100644 --- a/torch_xla/csrc/ops/cast.cpp +++ b/torch_xla/csrc/ops/cast.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/cast.h" #include "tensorflow/compiler/xla/primitive_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/cat.cpp b/torch_xla/csrc/ops/cat.cpp index 3b836202bc3..e0c8fa1ad35 100644 --- a/torch_xla/csrc/ops/cat.cpp +++ b/torch_xla/csrc/ops/cat.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/cat.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/cholesky.cpp b/torch_xla/csrc/ops/cholesky.cpp index 1d20082d4b0..6e3351c47d3 100644 --- a/torch_xla/csrc/ops/cholesky.cpp +++ b/torch_xla/csrc/ops/cholesky.cpp @@ -2,7 +2,7 @@ #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/collective_permute.cpp b/torch_xla/csrc/ops/collective_permute.cpp index 2eec2663d6f..9f6c1211d13 100644 --- a/torch_xla/csrc/ops/collective_permute.cpp +++ b/torch_xla/csrc/ops/collective_permute.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/constant_pad_nd.cpp b/torch_xla/csrc/ops/constant_pad_nd.cpp index 72364ac7dec..59607d339a0 100644 --- a/torch_xla/csrc/ops/constant_pad_nd.cpp +++ b/torch_xla/csrc/ops/constant_pad_nd.cpp @@ -3,7 +3,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp index 462dccfc8df..fb2f16b2bf0 100644 --- a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convolution.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/convolution_overrideable.cpp b/torch_xla/csrc/ops/convolution_overrideable.cpp index 4da6a512528..bb91c5ea49e 100644 --- a/torch_xla/csrc/ops/convolution_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_overrideable.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convolution.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/diagonal.cpp b/torch_xla/csrc/ops/diagonal.cpp index 5b43f5abf6c..a3f4874691b 100644 --- a/torch_xla/csrc/ops/diagonal.cpp +++ b/torch_xla/csrc/ops/diagonal.cpp @@ -4,7 +4,7 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" diff --git a/torch_xla/csrc/ops/diagonal_view_update.cpp b/torch_xla/csrc/ops/diagonal_view_update.cpp index 219a3953f20..7d6b1dea208 100644 --- a/torch_xla/csrc/ops/diagonal_view_update.cpp +++ b/torch_xla/csrc/ops/diagonal_view_update.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/diagonal_view_update.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/discrete_uniform.cpp b/torch_xla/csrc/ops/discrete_uniform.cpp index c74a16bc314..b0b72f48843 100644 --- a/torch_xla/csrc/ops/discrete_uniform.cpp +++ b/torch_xla/csrc/ops/discrete_uniform.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/discrete_uniform.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/expand.cpp b/torch_xla/csrc/ops/expand.cpp index ab63e16bdc6..63f555306f7 100644 --- a/torch_xla/csrc/ops/expand.cpp +++ b/torch_xla/csrc/ops/expand.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/expand.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/exponential.cpp b/torch_xla/csrc/ops/exponential.cpp index fafd6c20512..d326aa2475a 100644 --- a/torch_xla/csrc/ops/exponential.cpp +++ b/torch_xla/csrc/ops/exponential.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/exponential.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/flip.cpp b/torch_xla/csrc/ops/flip.cpp index 29d7b0c31d3..0e66bb1421a 100644 --- a/torch_xla/csrc/ops/flip.cpp +++ b/torch_xla/csrc/ops/flip.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/flip.h" #include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/gather.cpp b/torch_xla/csrc/ops/gather.cpp index 8b6119853e7..7471b6448da 100644 --- a/torch_xla/csrc/ops/gather.cpp +++ b/torch_xla/csrc/ops/gather.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/gather.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/generic_slice.cpp b/torch_xla/csrc/ops/generic_slice.cpp index 6cddcc91475..686a24ddaa2 100644 --- a/torch_xla/csrc/ops/generic_slice.cpp +++ b/torch_xla/csrc/ops/generic_slice.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/generic_slice.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/get_dimensions_size.cpp b/torch_xla/csrc/ops/get_dimensions_size.cpp index 29ae72f21fa..68da5d3d77b 100644 --- a/torch_xla/csrc/ops/get_dimensions_size.cpp +++ b/torch_xla/csrc/ops/get_dimensions_size.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/hardtanh_backward.cpp b/torch_xla/csrc/ops/hardtanh_backward.cpp index a7de77ac8e6..5239e81d75e 100644 --- a/torch_xla/csrc/ops/hardtanh_backward.cpp +++ b/torch_xla/csrc/ops/hardtanh_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/hardtanh_backward.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/scalar.h" diff --git a/torch_xla/csrc/ops/index_get.cpp b/torch_xla/csrc/ops/index_get.cpp index fb4df1b096c..496d2892e74 100644 --- a/torch_xla/csrc/ops/index_get.cpp +++ b/torch_xla/csrc/ops/index_get.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/index_get.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/index_put.cpp b/torch_xla/csrc/ops/index_put.cpp index d0cbad3a6cc..39c857cc10f 100644 --- a/torch_xla/csrc/ops/index_put.cpp +++ b/torch_xla/csrc/ops/index_put.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/index_put.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/index_select.cpp b/torch_xla/csrc/ops/index_select.cpp index 9e2c95dc499..62fb975cd3f 100644 --- a/torch_xla/csrc/ops/index_select.cpp +++ b/torch_xla/csrc/ops/index_select.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/index_select.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/kth_value.cpp b/torch_xla/csrc/ops/kth_value.cpp index 8c7e4de4fd8..ac2ef1ecd6c 100644 --- a/torch_xla/csrc/ops/kth_value.cpp +++ b/torch_xla/csrc/ops/kth_value.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/kth_value.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/leaky_relu.cpp b/torch_xla/csrc/ops/leaky_relu.cpp index 0bcbababe14..8ff87a458d4 100644 --- a/torch_xla/csrc/ops/leaky_relu.cpp +++ b/torch_xla/csrc/ops/leaky_relu.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/leaky_relu.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/leaky_relu_backward.cpp b/torch_xla/csrc/ops/leaky_relu_backward.cpp index 3103eecc46d..6ca5e2180b7 100644 --- a/torch_xla/csrc/ops/leaky_relu_backward.cpp +++ b/torch_xla/csrc/ops/leaky_relu_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/leaky_relu_backward.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/linear_interpolation.cpp b/torch_xla/csrc/ops/linear_interpolation.cpp index bb98c63d70b..26e7fc084cc 100644 --- a/torch_xla/csrc/ops/linear_interpolation.cpp +++ b/torch_xla/csrc/ops/linear_interpolation.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/linear_interpolation.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/log_softmax.cpp b/torch_xla/csrc/ops/log_softmax.cpp index 34fec5fbabe..b540fa6684e 100644 --- a/torch_xla/csrc/ops/log_softmax.cpp +++ b/torch_xla/csrc/ops/log_softmax.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/log_softmax.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/log_softmax_backward.cpp b/torch_xla/csrc/ops/log_softmax_backward.cpp index 5039ed3143c..bfed6d12bb8 100644 --- a/torch_xla/csrc/ops/log_softmax_backward.cpp +++ b/torch_xla/csrc/ops/log_softmax_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/log_softmax_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/logsumexp.cpp b/torch_xla/csrc/ops/logsumexp.cpp index 300bf273719..cd34fb20c50 100644 --- a/torch_xla/csrc/ops/logsumexp.cpp +++ b/torch_xla/csrc/ops/logsumexp.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/logsumexp.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/masked_fill.cpp b/torch_xla/csrc/ops/masked_fill.cpp index bcef4d4234d..2d4822735c0 100644 --- a/torch_xla/csrc/ops/masked_fill.cpp +++ b/torch_xla/csrc/ops/masked_fill.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/masked_fill.h" #include "tensorflow/compiler/xla/client/lib/constants.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/scalar.h" diff --git a/torch_xla/csrc/ops/masked_scatter.cpp b/torch_xla/csrc/ops/masked_scatter.cpp index 6d9a3212705..25af6ce9df8 100644 --- a/torch_xla/csrc/ops/masked_scatter.cpp +++ b/torch_xla/csrc/ops/masked_scatter.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/masked_scatter.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/masked_select.cpp b/torch_xla/csrc/ops/masked_select.cpp index 8dc8d5a6d4c..d458532ee8c 100644 --- a/torch_xla/csrc/ops/masked_select.cpp +++ b/torch_xla/csrc/ops/masked_select.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/masked_select.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/max_in_dim.cpp b/torch_xla/csrc/ops/max_in_dim.cpp index 44e23b1dad5..3b1b24e995c 100644 --- a/torch_xla/csrc/ops/max_in_dim.cpp +++ b/torch_xla/csrc/ops/max_in_dim.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/max_in_dim.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/max_pool_nd.cpp b/torch_xla/csrc/ops/max_pool_nd.cpp index 7906aef2698..2d3eb5b6886 100644 --- a/torch_xla/csrc/ops/max_pool_nd.cpp +++ b/torch_xla/csrc/ops/max_pool_nd.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/max_pool_nd.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_pool_nd_backward.cpp b/torch_xla/csrc/ops/max_pool_nd_backward.cpp index 663d6869d62..4fc1a02d761 100644 --- a/torch_xla/csrc/ops/max_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_pool_nd_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/max_pool_nd_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_unpool_nd.cpp b/torch_xla/csrc/ops/max_unpool_nd.cpp index 770656f3095..24984325a93 100644 --- a/torch_xla/csrc/ops/max_unpool_nd.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/max_unpool_nd.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp index 79e530924b9..c553b7ec6db 100644 --- a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/max_unpool_nd_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/mean.cpp b/torch_xla/csrc/ops/mean.cpp index 021d5662a65..13c98ffd3e7 100644 --- a/torch_xla/csrc/ops/mean.cpp +++ b/torch_xla/csrc/ops/mean.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/mean.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/min_in_dim.cpp b/torch_xla/csrc/ops/min_in_dim.cpp index ec5e0ba24b9..5b9b6181dc1 100644 --- a/torch_xla/csrc/ops/min_in_dim.cpp +++ b/torch_xla/csrc/ops/min_in_dim.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/min_in_dim.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/native_batch_norm_backward.cpp b/torch_xla/csrc/ops/native_batch_norm_backward.cpp index d806f8cd7b5..0bdad07b892 100644 --- a/torch_xla/csrc/ops/native_batch_norm_backward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/native_batch_norm_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/batch_norm.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/native_batch_norm_forward.cpp b/torch_xla/csrc/ops/native_batch_norm_forward.cpp index a7417589835..1e28d81f877 100644 --- a/torch_xla/csrc/ops/native_batch_norm_forward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_forward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/native_batch_norm_forward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/batch_norm.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/nms.cpp b/torch_xla/csrc/ops/nms.cpp index aed5f4f9179..2d7b2f4d5a8 100644 --- a/torch_xla/csrc/ops/nms.cpp +++ b/torch_xla/csrc/ops/nms.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/nms.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/nms_op.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/nonzero.cpp b/torch_xla/csrc/ops/nonzero.cpp index f590ecbb208..16a71c777d9 100644 --- a/torch_xla/csrc/ops/nonzero.cpp +++ b/torch_xla/csrc/ops/nonzero.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/nonzero.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/normal.cpp b/torch_xla/csrc/ops/normal.cpp index db5c4d7e339..b9f55194c44 100644 --- a/torch_xla/csrc/ops/normal.cpp +++ b/torch_xla/csrc/ops/normal.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/normal.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/random.h" diff --git a/torch_xla/csrc/ops/not_supported.cpp b/torch_xla/csrc/ops/not_supported.cpp index 02dfb41b6fc..0aa978c6e4c 100644 --- a/torch_xla/csrc/ops/not_supported.cpp +++ b/torch_xla/csrc/ops/not_supported.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/not_supported.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/permute.cpp b/torch_xla/csrc/ops/permute.cpp index 8195dbd69c7..da88bc3c850 100644 --- a/torch_xla/csrc/ops/permute.cpp +++ b/torch_xla/csrc/ops/permute.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/permute.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index 8e276b1259c..10070ee44ba 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/prod.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/put.cpp b/torch_xla/csrc/ops/put.cpp index 31e6747cdc8..afa9fe74cfb 100644 --- a/torch_xla/csrc/ops/put.cpp +++ b/torch_xla/csrc/ops/put.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/put.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/qr.cpp b/torch_xla/csrc/ops/qr.cpp index 1a887b96f73..e9e9f39d582 100644 --- a/torch_xla/csrc/ops/qr.cpp +++ b/torch_xla/csrc/ops/qr.cpp @@ -3,7 +3,7 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/qr.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/reflection_pad2d.cpp b/torch_xla/csrc/ops/reflection_pad2d.cpp index 134374e2b2c..ffbe02bfa48 100644 --- a/torch_xla/csrc/ops/reflection_pad2d.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/reflection_pad2d.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp index e5b36b1a4bc..120a622f298 100644 --- a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/reflection_pad2d_backward.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/repeat.cpp b/torch_xla/csrc/ops/repeat.cpp index f1600e577b1..88f41452a68 100644 --- a/torch_xla/csrc/ops/repeat.cpp +++ b/torch_xla/csrc/ops/repeat.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/repeat.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/replication_pad.cpp b/torch_xla/csrc/ops/replication_pad.cpp index 44872689321..785cef41d64 100644 --- a/torch_xla/csrc/ops/replication_pad.cpp +++ b/torch_xla/csrc/ops/replication_pad.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/replication_pad.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/replication_pad_backward.cpp b/torch_xla/csrc/ops/replication_pad_backward.cpp index c86031f24ac..d5447783d8a 100644 --- a/torch_xla/csrc/ops/replication_pad_backward.cpp +++ b/torch_xla/csrc/ops/replication_pad_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/replication_pad_backward.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/resize.cpp b/torch_xla/csrc/ops/resize.cpp index e708fdbb150..93f47d2e079 100644 --- a/torch_xla/csrc/ops/resize.cpp +++ b/torch_xla/csrc/ops/resize.cpp @@ -2,7 +2,7 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/rrelu_with_noise.cpp b/torch_xla/csrc/ops/rrelu_with_noise.cpp index 18d327de96b..18c23f24adc 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/rrelu_with_noise.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp index b66ad7bce0a..33dcb6b8dcd 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/rrelu_with_noise_backward.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/scalar.h" diff --git a/torch_xla/csrc/ops/scalar.cpp b/torch_xla/csrc/ops/scalar.cpp index bad4945502f..f5c11da192f 100644 --- a/torch_xla/csrc/ops/scalar.cpp +++ b/torch_xla/csrc/ops/scalar.cpp @@ -5,7 +5,7 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/scatter.cpp b/torch_xla/csrc/ops/scatter.cpp index acaf850bfbb..a112a87d70a 100644 --- a/torch_xla/csrc/ops/scatter.cpp +++ b/torch_xla/csrc/ops/scatter.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/scatter.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/scatter_add.cpp b/torch_xla/csrc/ops/scatter_add.cpp index a692c208bd4..b5a33bef9ec 100644 --- a/torch_xla/csrc/ops/scatter_add.cpp +++ b/torch_xla/csrc/ops/scatter_add.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/scatter_add.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/select.cpp b/torch_xla/csrc/ops/select.cpp index 15236c210cb..0c4574d0c1e 100644 --- a/torch_xla/csrc/ops/select.cpp +++ b/torch_xla/csrc/ops/select.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/select.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/softmax.cpp b/torch_xla/csrc/ops/softmax.cpp index a2453ce3c2c..3aa6b202cf6 100644 --- a/torch_xla/csrc/ops/softmax.cpp +++ b/torch_xla/csrc/ops/softmax.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/softmax.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/softmax_backward.cpp b/torch_xla/csrc/ops/softmax_backward.cpp index 37b8c75ca7e..dee56f2359c 100644 --- a/torch_xla/csrc/ops/softmax_backward.cpp +++ b/torch_xla/csrc/ops/softmax_backward.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/softmax_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/split.cpp b/torch_xla/csrc/ops/split.cpp index fc822a38353..970f9c146f1 100644 --- a/torch_xla/csrc/ops/split.cpp +++ b/torch_xla/csrc/ops/split.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/squeeze.cpp b/torch_xla/csrc/ops/squeeze.cpp index 0a53c1fe0a2..28fd30eedb8 100644 --- a/torch_xla/csrc/ops/squeeze.cpp +++ b/torch_xla/csrc/ops/squeeze.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/squeeze.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/stack.cpp b/torch_xla/csrc/ops/stack.cpp index b4ceb7078ec..798d2c78f38 100644 --- a/torch_xla/csrc/ops/stack.cpp +++ b/torch_xla/csrc/ops/stack.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/stack.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/std.cpp b/torch_xla/csrc/ops/std.cpp index 67cb1f2b447..d223effd760 100644 --- a/torch_xla/csrc/ops/std.cpp +++ b/torch_xla/csrc/ops/std.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/std.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/std_mean.cpp b/torch_xla/csrc/ops/std_mean.cpp index 7b8774b2f19..034039719b9 100644 --- a/torch_xla/csrc/ops/std_mean.cpp +++ b/torch_xla/csrc/ops/std_mean.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/std_mean.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/sum.cpp b/torch_xla/csrc/ops/sum.cpp index c42178c2cd0..aac52824531 100644 --- a/torch_xla/csrc/ops/sum.cpp +++ b/torch_xla/csrc/ops/sum.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/sum.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp index a6622a4f361..472c87297aa 100644 --- a/torch_xla/csrc/ops/symeig.cpp +++ b/torch_xla/csrc/ops/symeig.cpp @@ -3,7 +3,7 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/threshold.cpp b/torch_xla/csrc/ops/threshold.cpp index 86abe6001a5..c6ec0bd3bd2 100644 --- a/torch_xla/csrc/ops/threshold.cpp +++ b/torch_xla/csrc/ops/threshold.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/threshold.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/threshold_backward.cpp b/torch_xla/csrc/ops/threshold_backward.cpp index cb205884ac4..715993852da 100644 --- a/torch_xla/csrc/ops/threshold_backward.cpp +++ b/torch_xla/csrc/ops/threshold_backward.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/threshold_backward.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp index de3864fab97..4bfcbd8dfaa 100644 --- a/torch_xla/csrc/ops/topk.cpp +++ b/torch_xla/csrc/ops/topk.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/topk.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/triangular_solve.cpp b/torch_xla/csrc/ops/triangular_solve.cpp index af0ee2ba716..fd61a52c1db 100644 --- a/torch_xla/csrc/ops/triangular_solve.cpp +++ b/torch_xla/csrc/ops/triangular_solve.cpp @@ -2,7 +2,7 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/tril.cpp b/torch_xla/csrc/ops/tril.cpp index 36b1940609f..3646936189d 100644 --- a/torch_xla/csrc/ops/tril.cpp +++ b/torch_xla/csrc/ops/tril.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/tril.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" diff --git a/torch_xla/csrc/ops/triu.cpp b/torch_xla/csrc/ops/triu.cpp index 6a203ae16aa..b0f41bcace2 100644 --- a/torch_xla/csrc/ops/triu.cpp +++ b/torch_xla/csrc/ops/triu.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/triu.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" diff --git a/torch_xla/csrc/ops/uniform.cpp b/torch_xla/csrc/ops/uniform.cpp index d2c25128bb7..6cf9a031ae0 100644 --- a/torch_xla/csrc/ops/uniform.cpp +++ b/torch_xla/csrc/ops/uniform.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/uniform.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/unselect.cpp b/torch_xla/csrc/ops/unselect.cpp index aa72e971271..03173a88108 100644 --- a/torch_xla/csrc/ops/unselect.cpp +++ b/torch_xla/csrc/ops/unselect.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/unselect.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/unsqueeze.cpp b/torch_xla/csrc/ops/unsqueeze.cpp index a956f65d5e0..9ceb692495b 100644 --- a/torch_xla/csrc/ops/unsqueeze.cpp +++ b/torch_xla/csrc/ops/unsqueeze.cpp @@ -1,6 +1,6 @@ #include "torch_xla/csrc/ops/unsqueeze.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/update_slice.cpp b/torch_xla/csrc/ops/update_slice.cpp index 80799e56e6c..a44080be387 100644 --- a/torch_xla/csrc/ops/update_slice.cpp +++ b/torch_xla/csrc/ops/update_slice.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/update_slice.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/upsample_bilinear2d.cpp b/torch_xla/csrc/ops/upsample_bilinear2d.cpp index 671932a5be8..191ed780879 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp index ecaf20bd475..ef586c3e21b 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/upsample_nearest2d.cpp b/torch_xla/csrc/ops/upsample_nearest2d.cpp index b478afba719..24de7144bcd 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp index 1900ceb419c..b2fe2589704 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/var.cpp b/torch_xla/csrc/ops/var.cpp index 23a13e5b068..fc8954136ad 100644 --- a/torch_xla/csrc/ops/var.cpp +++ b/torch_xla/csrc/ops/var.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/var.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/var_mean.cpp b/torch_xla/csrc/ops/var_mean.cpp index c6270c6ee4a..73cc4013675 100644 --- a/torch_xla/csrc/ops/var_mean.cpp +++ b/torch_xla/csrc/ops/var_mean.cpp @@ -1,7 +1,7 @@ #include "torch_xla/csrc/ops/var_mean.h" #include "absl/strings/str_join.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/view.cpp b/torch_xla/csrc/ops/view.cpp index 9183c4a613f..7a3e2e9fba9 100644 --- a/torch_xla/csrc/ops/view.cpp +++ b/torch_xla/csrc/ops/view.cpp @@ -2,7 +2,7 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp index 29dbfed5992..0aad199cd03 100644 --- a/torch_xla/csrc/resize_ops.cpp +++ b/torch_xla/csrc/resize_ops.cpp @@ -6,7 +6,7 @@ #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" -#include "tensorflow/compiler/xla/xla_client/util.h" + #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/shape_builder.h" From 85c14e712f829b190f88afea181ae5749c7cb594 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 6 Oct 2021 13:36:03 -0700 Subject: [PATCH 6/9] Comment explaining cast to uint32_t --- torch_xla/csrc/ops/ops.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch_xla/csrc/ops/ops.h b/torch_xla/csrc/ops/ops.h index 9cbbda022fe..e19d8e7abc7 100644 --- a/torch_xla/csrc/ops/ops.h +++ b/torch_xla/csrc/ops/ops.h @@ -28,6 +28,7 @@ inline NodePtr ConstantOp(xla::Literal value) { inline NodePtr GenericOp( OpKind op, absl::Span operands, xla::Shape shape, Generic::LowerFn lower_fn, size_t num_outputs = 1, + // cast to uint32_t to avoid ambiguous constructor of uint128 torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { return MakeNode(std::move(op), operands, std::move(shape), std::move(lower_fn), num_outputs, hash_seed); @@ -37,6 +38,7 @@ inline NodePtr GenericOp( OpKind op, absl::Span operands, const std::function& shape_fn, Generic::LowerFn lower_fn, size_t num_outputs = 1, + // cast to uint32_t to avoid ambiguous constructor of uint128 torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9) { return MakeNode(std::move(op), operands, shape_fn, std::move(lower_fn), num_outputs, hash_seed); From b67be8e34decf2be8aded531830a3fafb506f2c8 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 6 Oct 2021 13:51:38 -0700 Subject: [PATCH 7/9] Fix MHash for absl::span --- torch_xla/csrc/ops/generic_slice.cpp | 4 +--- torch_xla/csrc/torch_util.h | 7 +++++++ 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/torch_xla/csrc/ops/generic_slice.cpp b/torch_xla/csrc/ops/generic_slice.cpp index 686a24ddaa2..4a3810337a5 100644 --- a/torch_xla/csrc/ops/generic_slice.cpp +++ b/torch_xla/csrc/ops/generic_slice.cpp @@ -30,9 +30,7 @@ GenericSlice::GenericSlice(const Value& input, absl::Span sizes) : Node(xla_generic_slice, {input}, [&]() { return NodeOutputShape(input, base_indices, sizes); }, - /*num_outputs=*/1, - torch::lazy::MHash(torch::lazy::Hash(base_indices), - torch::lazy::Hash(sizes))), + /*num_outputs=*/1, torch::lazy::MHash(base_indices, sizes)), base_indices_(base_indices.begin(), base_indices.end()), sizes_(sizes.begin(), sizes.end()) {} diff --git a/torch_xla/csrc/torch_util.h b/torch_xla/csrc/torch_util.h index cd7b2f13a54..a5ee2b7a243 100644 --- a/torch_xla/csrc/torch_util.h +++ b/torch_xla/csrc/torch_util.h @@ -53,5 +53,12 @@ torch::lazy::hash_t Hash(absl::Span values) { return torch::lazy::ContainerHash(values); } +// When specializing Hash(T) also specialize MHash(T, ...) since +// torch::lazy::MHash template won't be aware of the Hash(T) here +template +hash_t MHash(absl::Span value, Targs... Fargs) { + return HashCombine(Hash(value), MHash(Fargs...)); +} + } // namespace lazy } // namespace torch \ No newline at end of file From a2e9e79d0652e662dbd28ebaa9dd04cb45515c5c Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 6 Oct 2021 14:27:25 -0700 Subject: [PATCH 8/9] Fix clang-format --- torch_xla/csrc/computation.cpp | 1 - torch_xla/csrc/ir.cpp | 1 - torch_xla/csrc/op_by_op_executor.cpp | 1 - torch_xla/csrc/ops/adaptive_avg_pool2d.cpp | 1 - torch_xla/csrc/ops/adaptive_avg_pool3d.cpp | 1 - torch_xla/csrc/ops/adaptive_max_pool2d.cpp | 1 - torch_xla/csrc/ops/all.cpp | 1 - torch_xla/csrc/ops/all_to_all.cpp | 1 - torch_xla/csrc/ops/amax.cpp | 1 - torch_xla/csrc/ops/amin.cpp | 1 - torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp | 1 - torch_xla/csrc/ops/amp_update_scale.cpp | 1 - torch_xla/csrc/ops/any.cpp | 1 - torch_xla/csrc/ops/arg_max.cpp | 1 - torch_xla/csrc/ops/arg_min.cpp | 1 - torch_xla/csrc/ops/avg_pool_nd.cpp | 1 - torch_xla/csrc/ops/avg_pool_nd_backward.cpp | 1 - torch_xla/csrc/ops/bernoulli.cpp | 1 - torch_xla/csrc/ops/cast.cpp | 1 - torch_xla/csrc/ops/cat.cpp | 1 - torch_xla/csrc/ops/cholesky.cpp | 1 - torch_xla/csrc/ops/collective_permute.cpp | 1 - torch_xla/csrc/ops/constant_pad_nd.cpp | 1 - torch_xla/csrc/ops/convolution_backward_overrideable.cpp | 1 - torch_xla/csrc/ops/convolution_overrideable.cpp | 1 - torch_xla/csrc/ops/diagonal.cpp | 1 - torch_xla/csrc/ops/diagonal_view_update.cpp | 1 - torch_xla/csrc/ops/discrete_uniform.cpp | 1 - torch_xla/csrc/ops/expand.cpp | 1 - torch_xla/csrc/ops/exponential.cpp | 1 - torch_xla/csrc/ops/flip.cpp | 1 - torch_xla/csrc/ops/gather.cpp | 1 - torch_xla/csrc/ops/generic_slice.cpp | 1 - torch_xla/csrc/ops/get_dimensions_size.cpp | 1 - torch_xla/csrc/ops/hardtanh_backward.cpp | 1 - torch_xla/csrc/ops/index_get.cpp | 1 - torch_xla/csrc/ops/index_put.cpp | 1 - torch_xla/csrc/ops/index_select.cpp | 1 - torch_xla/csrc/ops/kth_value.cpp | 1 - torch_xla/csrc/ops/leaky_relu.cpp | 1 - torch_xla/csrc/ops/leaky_relu_backward.cpp | 1 - torch_xla/csrc/ops/linear_interpolation.cpp | 1 - torch_xla/csrc/ops/log_softmax.cpp | 1 - torch_xla/csrc/ops/log_softmax_backward.cpp | 1 - torch_xla/csrc/ops/logsumexp.cpp | 1 - torch_xla/csrc/ops/masked_fill.cpp | 1 - torch_xla/csrc/ops/masked_scatter.cpp | 1 - torch_xla/csrc/ops/masked_select.cpp | 1 - torch_xla/csrc/ops/max_in_dim.cpp | 1 - torch_xla/csrc/ops/max_pool_nd.cpp | 1 - torch_xla/csrc/ops/max_pool_nd_backward.cpp | 1 - torch_xla/csrc/ops/max_unpool_nd.cpp | 1 - torch_xla/csrc/ops/max_unpool_nd_backward.cpp | 1 - torch_xla/csrc/ops/mean.cpp | 1 - torch_xla/csrc/ops/min_in_dim.cpp | 1 - torch_xla/csrc/ops/native_batch_norm_backward.cpp | 1 - torch_xla/csrc/ops/native_batch_norm_forward.cpp | 1 - torch_xla/csrc/ops/nms.cpp | 1 - torch_xla/csrc/ops/nonzero.cpp | 1 - torch_xla/csrc/ops/normal.cpp | 1 - torch_xla/csrc/ops/not_supported.cpp | 1 - torch_xla/csrc/ops/permute.cpp | 1 - torch_xla/csrc/ops/prod.cpp | 1 - torch_xla/csrc/ops/put.cpp | 1 - torch_xla/csrc/ops/qr.cpp | 1 - torch_xla/csrc/ops/reflection_pad2d.cpp | 1 - torch_xla/csrc/ops/reflection_pad2d_backward.cpp | 1 - torch_xla/csrc/ops/repeat.cpp | 1 - torch_xla/csrc/ops/replication_pad.cpp | 1 - torch_xla/csrc/ops/replication_pad_backward.cpp | 1 - torch_xla/csrc/ops/resize.cpp | 1 - torch_xla/csrc/ops/rrelu_with_noise.cpp | 1 - torch_xla/csrc/ops/rrelu_with_noise_backward.cpp | 1 - torch_xla/csrc/ops/scalar.cpp | 1 - torch_xla/csrc/ops/scatter.cpp | 1 - torch_xla/csrc/ops/scatter_add.cpp | 1 - torch_xla/csrc/ops/select.cpp | 1 - torch_xla/csrc/ops/softmax.cpp | 1 - torch_xla/csrc/ops/softmax_backward.cpp | 1 - torch_xla/csrc/ops/split.cpp | 1 - torch_xla/csrc/ops/squeeze.cpp | 1 - torch_xla/csrc/ops/stack.cpp | 1 - torch_xla/csrc/ops/std.cpp | 1 - torch_xla/csrc/ops/std_mean.cpp | 1 - torch_xla/csrc/ops/sum.cpp | 1 - torch_xla/csrc/ops/symeig.cpp | 1 - torch_xla/csrc/ops/threshold.cpp | 1 - torch_xla/csrc/ops/threshold_backward.cpp | 1 - torch_xla/csrc/ops/topk.cpp | 1 - torch_xla/csrc/ops/triangular_solve.cpp | 1 - torch_xla/csrc/ops/tril.cpp | 1 - torch_xla/csrc/ops/triu.cpp | 1 - torch_xla/csrc/ops/uniform.cpp | 1 - torch_xla/csrc/ops/unselect.cpp | 1 - torch_xla/csrc/ops/unsqueeze.cpp | 1 - torch_xla/csrc/ops/update_slice.cpp | 1 - torch_xla/csrc/ops/upsample_bilinear2d.cpp | 1 - torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp | 1 - torch_xla/csrc/ops/upsample_nearest2d.cpp | 1 - torch_xla/csrc/ops/upsample_nearest2d_backward.cpp | 1 - torch_xla/csrc/ops/var.cpp | 1 - torch_xla/csrc/ops/var_mean.cpp | 1 - torch_xla/csrc/ops/view.cpp | 1 - torch_xla/csrc/resize_ops.cpp | 1 - 104 files changed, 104 deletions(-) diff --git a/torch_xla/csrc/computation.cpp b/torch_xla/csrc/computation.cpp index 96510e2c0a4..8ac94130538 100644 --- a/torch_xla/csrc/computation.cpp +++ b/torch_xla/csrc/computation.cpp @@ -2,7 +2,6 @@ #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - namespace torch_xla { Computation::Computation(std::string name, xla::XlaComputation computation) diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp index d79d3c388f3..2a4d5cdfaf2 100644 --- a/torch_xla/csrc/ir.cpp +++ b/torch_xla/csrc/ir.cpp @@ -7,7 +7,6 @@ #include "tensorflow/compiler/xla/xla_client/cache.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" - #include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/op_by_op_executor.cpp b/torch_xla/csrc/op_by_op_executor.cpp index cb5e1e4372e..69b5df572b8 100644 --- a/torch_xla/csrc/op_by_op_executor.cpp +++ b/torch_xla/csrc/op_by_op_executor.cpp @@ -8,7 +8,6 @@ #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/metrics.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" - #include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "torch/csrc/lazy/core/hash.h" #include "torch_xla/csrc/device.h" diff --git a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp index 82babc3d469..45d59ba0b5a 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool2d.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/adaptive_avg_pool2d.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp index 9dd32ed47fe..3d4a72748ad 100644 --- a/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp +++ b/torch_xla/csrc/ops/adaptive_avg_pool3d.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/adaptive_avg_pool3d.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp index 6579f52614e..a393057bd10 100644 --- a/torch_xla/csrc/ops/adaptive_max_pool2d.cpp +++ b/torch_xla/csrc/ops/adaptive_max_pool2d.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/adaptive_max_pool2d.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/all.cpp b/torch_xla/csrc/ops/all.cpp index b167a38f185..ed6af7a85ce 100644 --- a/torch_xla/csrc/ops/all.cpp +++ b/torch_xla/csrc/ops/all.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/all.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/all_to_all.cpp b/torch_xla/csrc/ops/all_to_all.cpp index 064406c0efe..903935e256e 100644 --- a/torch_xla/csrc/ops/all_to_all.cpp +++ b/torch_xla/csrc/ops/all_to_all.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/amax.cpp b/torch_xla/csrc/ops/amax.cpp index b6a9bdf5d21..18c6f4160bd 100644 --- a/torch_xla/csrc/ops/amax.cpp +++ b/torch_xla/csrc/ops/amax.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/amax.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/amin.cpp b/torch_xla/csrc/ops/amin.cpp index c6fede8714e..ab87179b60e 100644 --- a/torch_xla/csrc/ops/amin.cpp +++ b/torch_xla/csrc/ops/amin.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/amin.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp index 1ab741408b5..bab99cbd9b2 100644 --- a/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp +++ b/torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/amp_foreach_non_finite_check_and_unscale.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/tensor_util.h" diff --git a/torch_xla/csrc/ops/amp_update_scale.cpp b/torch_xla/csrc/ops/amp_update_scale.cpp index 6570c057c74..a4f3fceb710 100644 --- a/torch_xla/csrc/ops/amp_update_scale.cpp +++ b/torch_xla/csrc/ops/amp_update_scale.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/amp_update_scale.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" #include "torch_xla/csrc/tensor_util.h" diff --git a/torch_xla/csrc/ops/any.cpp b/torch_xla/csrc/ops/any.cpp index d19dae9286c..2399a9ee771 100644 --- a/torch_xla/csrc/ops/any.cpp +++ b/torch_xla/csrc/ops/any.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/any.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/arg_max.cpp b/torch_xla/csrc/ops/arg_max.cpp index eb6815e97be..3865096abbf 100644 --- a/torch_xla/csrc/ops/arg_max.cpp +++ b/torch_xla/csrc/ops/arg_max.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/arg_max.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/arg_min.cpp b/torch_xla/csrc/ops/arg_min.cpp index 381ce9a5d18..84ae2557a73 100644 --- a/torch_xla/csrc/ops/arg_min.cpp +++ b/torch_xla/csrc/ops/arg_min.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/arg_min.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/avg_pool_nd.cpp b/torch_xla/csrc/ops/avg_pool_nd.cpp index 38a5dad7922..83391a122f5 100644 --- a/torch_xla/csrc/ops/avg_pool_nd.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp index c304da1d784..554aae79f3f 100644 --- a/torch_xla/csrc/ops/avg_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/avg_pool_nd_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/avg_pool_nd_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/bernoulli.cpp b/torch_xla/csrc/ops/bernoulli.cpp index 43e59e9bd41..a64c73678af 100644 --- a/torch_xla/csrc/ops/bernoulli.cpp +++ b/torch_xla/csrc/ops/bernoulli.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/bernoulli.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/cast.cpp b/torch_xla/csrc/ops/cast.cpp index a3b814f20b1..56aee6fa7e0 100644 --- a/torch_xla/csrc/ops/cast.cpp +++ b/torch_xla/csrc/ops/cast.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/cast.h" #include "tensorflow/compiler/xla/primitive_util.h" - #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/cat.cpp b/torch_xla/csrc/ops/cat.cpp index e0c8fa1ad35..0c68ce7c8d0 100644 --- a/torch_xla/csrc/ops/cat.cpp +++ b/torch_xla/csrc/ops/cat.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/cat.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/cholesky.cpp b/torch_xla/csrc/ops/cholesky.cpp index 6e3351c47d3..08bfe04d117 100644 --- a/torch_xla/csrc/ops/cholesky.cpp +++ b/torch_xla/csrc/ops/cholesky.cpp @@ -2,7 +2,6 @@ #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/xla_builder.h" - #include "torch_xla/csrc/lowering_context.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/collective_permute.cpp b/torch_xla/csrc/ops/collective_permute.cpp index 9f6c1211d13..84adb421f86 100644 --- a/torch_xla/csrc/ops/collective_permute.cpp +++ b/torch_xla/csrc/ops/collective_permute.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/constant_pad_nd.cpp b/torch_xla/csrc/ops/constant_pad_nd.cpp index 59607d339a0..7f5f69a5999 100644 --- a/torch_xla/csrc/ops/constant_pad_nd.cpp +++ b/torch_xla/csrc/ops/constant_pad_nd.cpp @@ -3,7 +3,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp index fb2f16b2bf0..8fbf83ca975 100644 --- a/torch_xla/csrc/ops/convolution_backward_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_backward_overrideable.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/convolution.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/convolution_overrideable.cpp b/torch_xla/csrc/ops/convolution_overrideable.cpp index bb91c5ea49e..63cdc672bf0 100644 --- a/torch_xla/csrc/ops/convolution_overrideable.cpp +++ b/torch_xla/csrc/ops/convolution_overrideable.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/convolution.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/diagonal.cpp b/torch_xla/csrc/ops/diagonal.cpp index a3f4874691b..37ce0351fa6 100644 --- a/torch_xla/csrc/ops/diagonal.cpp +++ b/torch_xla/csrc/ops/diagonal.cpp @@ -4,7 +4,6 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" diff --git a/torch_xla/csrc/ops/diagonal_view_update.cpp b/torch_xla/csrc/ops/diagonal_view_update.cpp index 7d6b1dea208..d5a5cb45118 100644 --- a/torch_xla/csrc/ops/diagonal_view_update.cpp +++ b/torch_xla/csrc/ops/diagonal_view_update.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/diagonal_view_update.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/discrete_uniform.cpp b/torch_xla/csrc/ops/discrete_uniform.cpp index b0b72f48843..1707433ff36 100644 --- a/torch_xla/csrc/ops/discrete_uniform.cpp +++ b/torch_xla/csrc/ops/discrete_uniform.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/discrete_uniform.h" - #include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/expand.cpp b/torch_xla/csrc/ops/expand.cpp index 63f555306f7..ba816d250ee 100644 --- a/torch_xla/csrc/ops/expand.cpp +++ b/torch_xla/csrc/ops/expand.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/expand.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/exponential.cpp b/torch_xla/csrc/ops/exponential.cpp index d326aa2475a..8238b5a7425 100644 --- a/torch_xla/csrc/ops/exponential.cpp +++ b/torch_xla/csrc/ops/exponential.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/exponential.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/flip.cpp b/torch_xla/csrc/ops/flip.cpp index 0e66bb1421a..089f292ea75 100644 --- a/torch_xla/csrc/ops/flip.cpp +++ b/torch_xla/csrc/ops/flip.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/flip.h" #include "tensorflow/compiler/xla/client/xla_builder.h" - #include "torch_xla/csrc/lowering_context.h" namespace torch_xla { diff --git a/torch_xla/csrc/ops/gather.cpp b/torch_xla/csrc/ops/gather.cpp index 7471b6448da..de977e645c7 100644 --- a/torch_xla/csrc/ops/gather.cpp +++ b/torch_xla/csrc/ops/gather.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/gather.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/generic_slice.cpp b/torch_xla/csrc/ops/generic_slice.cpp index 4a3810337a5..b75ba2a46b3 100644 --- a/torch_xla/csrc/ops/generic_slice.cpp +++ b/torch_xla/csrc/ops/generic_slice.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/generic_slice.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/get_dimensions_size.cpp b/torch_xla/csrc/ops/get_dimensions_size.cpp index 68da5d3d77b..5173908dc28 100644 --- a/torch_xla/csrc/ops/get_dimensions_size.cpp +++ b/torch_xla/csrc/ops/get_dimensions_size.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/hardtanh_backward.cpp b/torch_xla/csrc/ops/hardtanh_backward.cpp index 5239e81d75e..642a64cb692 100644 --- a/torch_xla/csrc/ops/hardtanh_backward.cpp +++ b/torch_xla/csrc/ops/hardtanh_backward.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/hardtanh_backward.h" - #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/scalar.h" diff --git a/torch_xla/csrc/ops/index_get.cpp b/torch_xla/csrc/ops/index_get.cpp index 496d2892e74..ddcf57ad05e 100644 --- a/torch_xla/csrc/ops/index_get.cpp +++ b/torch_xla/csrc/ops/index_get.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/index_get.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/index_put.cpp b/torch_xla/csrc/ops/index_put.cpp index 39c857cc10f..dafaf11e7ef 100644 --- a/torch_xla/csrc/ops/index_put.cpp +++ b/torch_xla/csrc/ops/index_put.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/index_put.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/index_select.cpp b/torch_xla/csrc/ops/index_select.cpp index 62fb975cd3f..ee448725ac1 100644 --- a/torch_xla/csrc/ops/index_select.cpp +++ b/torch_xla/csrc/ops/index_select.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/index_select.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/kth_value.cpp b/torch_xla/csrc/ops/kth_value.cpp index ac2ef1ecd6c..a18d2dec34d 100644 --- a/torch_xla/csrc/ops/kth_value.cpp +++ b/torch_xla/csrc/ops/kth_value.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/kth_value.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/leaky_relu.cpp b/torch_xla/csrc/ops/leaky_relu.cpp index 8ff87a458d4..92eacb38121 100644 --- a/torch_xla/csrc/ops/leaky_relu.cpp +++ b/torch_xla/csrc/ops/leaky_relu.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/leaky_relu.h" - #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/leaky_relu_backward.cpp b/torch_xla/csrc/ops/leaky_relu_backward.cpp index 6ca5e2180b7..f30001c7833 100644 --- a/torch_xla/csrc/ops/leaky_relu_backward.cpp +++ b/torch_xla/csrc/ops/leaky_relu_backward.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/leaky_relu_backward.h" - #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/linear_interpolation.cpp b/torch_xla/csrc/ops/linear_interpolation.cpp index 26e7fc084cc..4ab2a091053 100644 --- a/torch_xla/csrc/ops/linear_interpolation.cpp +++ b/torch_xla/csrc/ops/linear_interpolation.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/linear_interpolation.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/log_softmax.cpp b/torch_xla/csrc/ops/log_softmax.cpp index b540fa6684e..c45caee0f18 100644 --- a/torch_xla/csrc/ops/log_softmax.cpp +++ b/torch_xla/csrc/ops/log_softmax.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/log_softmax.h" - #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/log_softmax_backward.cpp b/torch_xla/csrc/ops/log_softmax_backward.cpp index bfed6d12bb8..20ec9a145c8 100644 --- a/torch_xla/csrc/ops/log_softmax_backward.cpp +++ b/torch_xla/csrc/ops/log_softmax_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/log_softmax_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/logsumexp.cpp b/torch_xla/csrc/ops/logsumexp.cpp index cd34fb20c50..51e2b669b9e 100644 --- a/torch_xla/csrc/ops/logsumexp.cpp +++ b/torch_xla/csrc/ops/logsumexp.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/logsumexp.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/masked_fill.cpp b/torch_xla/csrc/ops/masked_fill.cpp index 2d4822735c0..c14275b03da 100644 --- a/torch_xla/csrc/ops/masked_fill.cpp +++ b/torch_xla/csrc/ops/masked_fill.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/masked_fill.h" #include "tensorflow/compiler/xla/client/lib/constants.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/scalar.h" diff --git a/torch_xla/csrc/ops/masked_scatter.cpp b/torch_xla/csrc/ops/masked_scatter.cpp index 25af6ce9df8..020aa52aa27 100644 --- a/torch_xla/csrc/ops/masked_scatter.cpp +++ b/torch_xla/csrc/ops/masked_scatter.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/masked_scatter.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/masked_select.cpp b/torch_xla/csrc/ops/masked_select.cpp index d458532ee8c..cb680d8ee8c 100644 --- a/torch_xla/csrc/ops/masked_select.cpp +++ b/torch_xla/csrc/ops/masked_select.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/masked_select.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/max_in_dim.cpp b/torch_xla/csrc/ops/max_in_dim.cpp index 3b1b24e995c..a42134fcd65 100644 --- a/torch_xla/csrc/ops/max_in_dim.cpp +++ b/torch_xla/csrc/ops/max_in_dim.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/max_in_dim.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/max_pool_nd.cpp b/torch_xla/csrc/ops/max_pool_nd.cpp index 2d3eb5b6886..7ed350447f3 100644 --- a/torch_xla/csrc/ops/max_pool_nd.cpp +++ b/torch_xla/csrc/ops/max_pool_nd.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/max_pool_nd.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_pool_nd_backward.cpp b/torch_xla/csrc/ops/max_pool_nd_backward.cpp index 4fc1a02d761..19fae3fd114 100644 --- a/torch_xla/csrc/ops/max_pool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_pool_nd_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/max_pool_nd_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_unpool_nd.cpp b/torch_xla/csrc/ops/max_unpool_nd.cpp index 24984325a93..9609f08063a 100644 --- a/torch_xla/csrc/ops/max_unpool_nd.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/max_unpool_nd.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp index c553b7ec6db..5f639d8de7f 100644 --- a/torch_xla/csrc/ops/max_unpool_nd_backward.cpp +++ b/torch_xla/csrc/ops/max_unpool_nd_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/max_unpool_nd_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/pooling.h" diff --git a/torch_xla/csrc/ops/mean.cpp b/torch_xla/csrc/ops/mean.cpp index 13c98ffd3e7..04cc9cfeaef 100644 --- a/torch_xla/csrc/ops/mean.cpp +++ b/torch_xla/csrc/ops/mean.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/mean.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/min_in_dim.cpp b/torch_xla/csrc/ops/min_in_dim.cpp index 5b9b6181dc1..79acdbaa23c 100644 --- a/torch_xla/csrc/ops/min_in_dim.cpp +++ b/torch_xla/csrc/ops/min_in_dim.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/min_in_dim.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/native_batch_norm_backward.cpp b/torch_xla/csrc/ops/native_batch_norm_backward.cpp index 0bdad07b892..5d67164c192 100644 --- a/torch_xla/csrc/ops/native_batch_norm_backward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/native_batch_norm_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/batch_norm.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/native_batch_norm_forward.cpp b/torch_xla/csrc/ops/native_batch_norm_forward.cpp index 1e28d81f877..7d695be2e23 100644 --- a/torch_xla/csrc/ops/native_batch_norm_forward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_forward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/native_batch_norm_forward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/batch_norm.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/nms.cpp b/torch_xla/csrc/ops/nms.cpp index 2d7b2f4d5a8..8d7ce432310 100644 --- a/torch_xla/csrc/ops/nms.cpp +++ b/torch_xla/csrc/ops/nms.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/nms.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/nms_op.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/nonzero.cpp b/torch_xla/csrc/ops/nonzero.cpp index 16a71c777d9..432dcf433a3 100644 --- a/torch_xla/csrc/ops/nonzero.cpp +++ b/torch_xla/csrc/ops/nonzero.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/nonzero.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/tensor_util.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/normal.cpp b/torch_xla/csrc/ops/normal.cpp index b9f55194c44..925613a57e9 100644 --- a/torch_xla/csrc/ops/normal.cpp +++ b/torch_xla/csrc/ops/normal.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/normal.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/random.h" diff --git a/torch_xla/csrc/ops/not_supported.cpp b/torch_xla/csrc/ops/not_supported.cpp index 0aa978c6e4c..823675fcab9 100644 --- a/torch_xla/csrc/ops/not_supported.cpp +++ b/torch_xla/csrc/ops/not_supported.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/not_supported.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/permute.cpp b/torch_xla/csrc/ops/permute.cpp index da88bc3c850..1e14779df33 100644 --- a/torch_xla/csrc/ops/permute.cpp +++ b/torch_xla/csrc/ops/permute.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/permute.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index 10070ee44ba..b20328494b8 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/prod.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/put.cpp b/torch_xla/csrc/ops/put.cpp index afa9fe74cfb..92655c66f0a 100644 --- a/torch_xla/csrc/ops/put.cpp +++ b/torch_xla/csrc/ops/put.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/put.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/qr.cpp b/torch_xla/csrc/ops/qr.cpp index e9e9f39d582..62af51d89d1 100644 --- a/torch_xla/csrc/ops/qr.cpp +++ b/torch_xla/csrc/ops/qr.cpp @@ -3,7 +3,6 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/qr.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/reflection_pad2d.cpp b/torch_xla/csrc/ops/reflection_pad2d.cpp index ffbe02bfa48..ddee4cb6d99 100644 --- a/torch_xla/csrc/ops/reflection_pad2d.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/reflection_pad2d.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp index 120a622f298..8775020f192 100644 --- a/torch_xla/csrc/ops/reflection_pad2d_backward.cpp +++ b/torch_xla/csrc/ops/reflection_pad2d_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/reflection_pad2d_backward.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/repeat.cpp b/torch_xla/csrc/ops/repeat.cpp index 88f41452a68..14b9fd3ce01 100644 --- a/torch_xla/csrc/ops/repeat.cpp +++ b/torch_xla/csrc/ops/repeat.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/repeat.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/replication_pad.cpp b/torch_xla/csrc/ops/replication_pad.cpp index 785cef41d64..dc623e1f8ff 100644 --- a/torch_xla/csrc/ops/replication_pad.cpp +++ b/torch_xla/csrc/ops/replication_pad.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/replication_pad.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/replication_pad_backward.cpp b/torch_xla/csrc/ops/replication_pad_backward.cpp index d5447783d8a..aeb92ad6509 100644 --- a/torch_xla/csrc/ops/replication_pad_backward.cpp +++ b/torch_xla/csrc/ops/replication_pad_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/replication_pad_backward.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/resize.cpp b/torch_xla/csrc/ops/resize.cpp index 93f47d2e079..e523d0cd82b 100644 --- a/torch_xla/csrc/ops/resize.cpp +++ b/torch_xla/csrc/ops/resize.cpp @@ -2,7 +2,6 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/rrelu_with_noise.cpp b/torch_xla/csrc/ops/rrelu_with_noise.cpp index 18c23f24adc..1b01b227e5c 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/rrelu_with_noise.h" - #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp index 33dcb6b8dcd..99f4c232e06 100644 --- a/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp +++ b/torch_xla/csrc/ops/rrelu_with_noise_backward.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/rrelu_with_noise_backward.h" - #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/scalar.h" diff --git a/torch_xla/csrc/ops/scalar.cpp b/torch_xla/csrc/ops/scalar.cpp index f5c11da192f..d3c8e174b74 100644 --- a/torch_xla/csrc/ops/scalar.cpp +++ b/torch_xla/csrc/ops/scalar.cpp @@ -5,7 +5,6 @@ #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/scatter.cpp b/torch_xla/csrc/ops/scatter.cpp index a112a87d70a..8671fae9981 100644 --- a/torch_xla/csrc/ops/scatter.cpp +++ b/torch_xla/csrc/ops/scatter.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/scatter.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/scatter_add.cpp b/torch_xla/csrc/ops/scatter_add.cpp index b5a33bef9ec..6fc07740764 100644 --- a/torch_xla/csrc/ops/scatter_add.cpp +++ b/torch_xla/csrc/ops/scatter_add.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/scatter_add.h" - #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/select.cpp b/torch_xla/csrc/ops/select.cpp index 0c4574d0c1e..9aa4b98a44b 100644 --- a/torch_xla/csrc/ops/select.cpp +++ b/torch_xla/csrc/ops/select.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/select.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/xla_ops.h" diff --git a/torch_xla/csrc/ops/softmax.cpp b/torch_xla/csrc/ops/softmax.cpp index 3aa6b202cf6..0b4157e47dc 100644 --- a/torch_xla/csrc/ops/softmax.cpp +++ b/torch_xla/csrc/ops/softmax.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/softmax.h" - #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/softmax_backward.cpp b/torch_xla/csrc/ops/softmax_backward.cpp index dee56f2359c..ea47f0cefa2 100644 --- a/torch_xla/csrc/ops/softmax_backward.cpp +++ b/torch_xla/csrc/ops/softmax_backward.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/softmax_backward.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/softmax_builder.h" diff --git a/torch_xla/csrc/ops/split.cpp b/torch_xla/csrc/ops/split.cpp index 970f9c146f1..9c4cb581637 100644 --- a/torch_xla/csrc/ops/split.cpp +++ b/torch_xla/csrc/ops/split.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/squeeze.cpp b/torch_xla/csrc/ops/squeeze.cpp index 28fd30eedb8..c25b5be31de 100644 --- a/torch_xla/csrc/ops/squeeze.cpp +++ b/torch_xla/csrc/ops/squeeze.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/squeeze.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/stack.cpp b/torch_xla/csrc/ops/stack.cpp index 798d2c78f38..d9d1b1ea564 100644 --- a/torch_xla/csrc/ops/stack.cpp +++ b/torch_xla/csrc/ops/stack.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/stack.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/std.cpp b/torch_xla/csrc/ops/std.cpp index d223effd760..98d47716439 100644 --- a/torch_xla/csrc/ops/std.cpp +++ b/torch_xla/csrc/ops/std.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/std.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/std_mean.cpp b/torch_xla/csrc/ops/std_mean.cpp index 034039719b9..6ab83526f95 100644 --- a/torch_xla/csrc/ops/std_mean.cpp +++ b/torch_xla/csrc/ops/std_mean.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/std_mean.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/reduction.h" diff --git a/torch_xla/csrc/ops/sum.cpp b/torch_xla/csrc/ops/sum.cpp index aac52824531..050e5e01ac3 100644 --- a/torch_xla/csrc/ops/sum.cpp +++ b/torch_xla/csrc/ops/sum.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/sum.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/convert_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp index 472c87297aa..055a561a3c0 100644 --- a/torch_xla/csrc/ops/symeig.cpp +++ b/torch_xla/csrc/ops/symeig.cpp @@ -3,7 +3,6 @@ #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/threshold.cpp b/torch_xla/csrc/ops/threshold.cpp index c6ec0bd3bd2..cec2e47fbb6 100644 --- a/torch_xla/csrc/ops/threshold.cpp +++ b/torch_xla/csrc/ops/threshold.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/threshold.h" - #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/threshold_backward.cpp b/torch_xla/csrc/ops/threshold_backward.cpp index 715993852da..928903c5e4a 100644 --- a/torch_xla/csrc/ops/threshold_backward.cpp +++ b/torch_xla/csrc/ops/threshold_backward.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/threshold_backward.h" - #include "torch_xla/csrc/elementwise.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/topk.cpp b/torch_xla/csrc/ops/topk.cpp index 4bfcbd8dfaa..234c84ccbcc 100644 --- a/torch_xla/csrc/ops/topk.cpp +++ b/torch_xla/csrc/ops/topk.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/topk.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" #include "torch_xla/csrc/xla_lower_util.h" diff --git a/torch_xla/csrc/ops/triangular_solve.cpp b/torch_xla/csrc/ops/triangular_solve.cpp index fd61a52c1db..9282098702e 100644 --- a/torch_xla/csrc/ops/triangular_solve.cpp +++ b/torch_xla/csrc/ops/triangular_solve.cpp @@ -2,7 +2,6 @@ #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/layout_util.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/tril.cpp b/torch_xla/csrc/ops/tril.cpp index 3646936189d..000fb2437fa 100644 --- a/torch_xla/csrc/ops/tril.cpp +++ b/torch_xla/csrc/ops/tril.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/tril.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" diff --git a/torch_xla/csrc/ops/triu.cpp b/torch_xla/csrc/ops/triu.cpp index b0f41bcace2..51997797976 100644 --- a/torch_xla/csrc/ops/triu.cpp +++ b/torch_xla/csrc/ops/triu.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/triu.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/matrix.h" diff --git a/torch_xla/csrc/ops/uniform.cpp b/torch_xla/csrc/ops/uniform.cpp index 6cf9a031ae0..25aa31d7f6e 100644 --- a/torch_xla/csrc/ops/uniform.cpp +++ b/torch_xla/csrc/ops/uniform.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/uniform.h" - #include "tensorflow/compiler/xla/xla_client/xla_util.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/unselect.cpp b/torch_xla/csrc/ops/unselect.cpp index 03173a88108..c926de324b4 100644 --- a/torch_xla/csrc/ops/unselect.cpp +++ b/torch_xla/csrc/ops/unselect.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/unselect.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/unsqueeze.cpp b/torch_xla/csrc/ops/unsqueeze.cpp index 9ceb692495b..7d821f7f7d0 100644 --- a/torch_xla/csrc/ops/unsqueeze.cpp +++ b/torch_xla/csrc/ops/unsqueeze.cpp @@ -1,6 +1,5 @@ #include "torch_xla/csrc/ops/unsqueeze.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/ops/update_slice.cpp b/torch_xla/csrc/ops/update_slice.cpp index a44080be387..a34e7a8d526 100644 --- a/torch_xla/csrc/ops/update_slice.cpp +++ b/torch_xla/csrc/ops/update_slice.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/update_slice.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/upsample_bilinear2d.cpp b/torch_xla/csrc/ops/upsample_bilinear2d.cpp index 191ed780879..5583d67bd9a 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp index ef586c3e21b..407e2c2b4f8 100644 --- a/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_bilinear2d_backward.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/upsample_nearest2d.cpp b/torch_xla/csrc/ops/upsample_nearest2d.cpp index 24de7144bcd..3b117c24b18 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/util.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp index b2fe2589704..e82697c8ac8 100644 --- a/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp +++ b/torch_xla/csrc/ops/upsample_nearest2d_backward.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" - #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/resize_ops.h" diff --git a/torch_xla/csrc/ops/var.cpp b/torch_xla/csrc/ops/var.cpp index fc8954136ad..7f37b22de96 100644 --- a/torch_xla/csrc/ops/var.cpp +++ b/torch_xla/csrc/ops/var.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/var.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/var_mean.cpp b/torch_xla/csrc/ops/var_mean.cpp index 73cc4013675..f1793f5f473 100644 --- a/torch_xla/csrc/ops/var_mean.cpp +++ b/torch_xla/csrc/ops/var_mean.cpp @@ -1,7 +1,6 @@ #include "torch_xla/csrc/ops/var_mean.h" #include "absl/strings/str_join.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" #include "torch_xla/csrc/ops/infer_output_shape.h" diff --git a/torch_xla/csrc/ops/view.cpp b/torch_xla/csrc/ops/view.cpp index 7a3e2e9fba9..c3e3cceb073 100644 --- a/torch_xla/csrc/ops/view.cpp +++ b/torch_xla/csrc/ops/view.cpp @@ -2,7 +2,6 @@ #include "absl/strings/str_join.h" #include "tensorflow/compiler/xla/shape_util.h" - #include "torch_xla/csrc/data_ops.h" #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/lowering_context.h" diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp index 0aad199cd03..3ddd152d689 100644 --- a/torch_xla/csrc/resize_ops.cpp +++ b/torch_xla/csrc/resize_ops.cpp @@ -6,7 +6,6 @@ #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_client/debug_macros.h" #include "tensorflow/compiler/xla/xla_client/sys_util.h" - #include "torch_xla/csrc/helpers.h" #include "torch_xla/csrc/shape_builder.h" From cbb67010d537a882b83d06afea330648f7f3d831 Mon Sep 17 00:00:00 2001 From: Will Constable Date: Mon, 11 Oct 2021 09:27:53 -0700 Subject: [PATCH 9/9] Delete .torch_pin --- torch_patches/.torch_pin | 1 - 1 file changed, 1 deletion(-) delete mode 100644 torch_patches/.torch_pin diff --git a/torch_patches/.torch_pin b/torch_patches/.torch_pin deleted file mode 100644 index f08d56e5bed..00000000000 --- a/torch_patches/.torch_pin +++ /dev/null @@ -1 +0,0 @@ -#66181