3 changes: 1 addition & 2 deletions torch_xla/csrc/computation.cpp
@@ -1,14 +1,13 @@
 #include "torch_xla/csrc/computation.h"
 
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
-#include "tensorflow/compiler/xla/xla_client/util.h"
 
 namespace torch_xla {
 
 Computation::Computation(std::string name, xla::XlaComputation computation)
     : name_(std::move(name)), computation_(std::move(computation)) {
  program_shape_ = ConsumeValue(computation_.GetProgramShape());
-  hash_ = xla::util::MHash(name_, computation_.proto().SerializeAsString());
+  hash_ = torch::lazy::MHash(name_, computation_.proto().SerializeAsString());
 }
 
 }  // namespace torch_xla
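For context, a minimal sketch of the replacement call, assuming torch::lazy::MHash from torch/csrc/lazy/core/hash.h hashes each argument and mixes the partial results, as xla::util::MHash did; HashNamedPayload is a hypothetical helper, not part of this diff:

// Sketch only: the torch::lazy hashing API this diff migrates to.
#include <string>

#include "torch/csrc/lazy/core/hash.h"

// Mirrors Computation's hash computation: MHash is variadic, hashing each
// argument and combining the partial hashes into one hash_t.
torch::lazy::hash_t HashNamedPayload(const std::string& name,
                                     const std::string& payload) {
  return torch::lazy::MHash(name, payload);
}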
5 changes: 3 additions & 2 deletions torch_xla/csrc/computation.h
@@ -6,6 +6,7 @@
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/xla_client/types.h"
+#include "torch/csrc/lazy/core/hash.h"
 
 namespace torch_xla {
 
@@ -19,13 +20,13 @@ class Computation {
 
   const xla::ProgramShape& program_shape() const { return program_shape_; }
 
-  const xla::hash_t& hash() const { return hash_; }
+  const torch::lazy::hash_t& hash() const { return hash_; }
 
  private:
   std::string name_;
   xla::XlaComputation computation_;
   xla::ProgramShape program_shape_;
-  xla::hash_t hash_;
+  torch::lazy::hash_t hash_;
 };
 
 using ComputationPtr = std::shared_ptr<Computation>;
5 changes: 3 additions & 2 deletions torch_xla/csrc/debug_util.cpp
@@ -10,6 +10,7 @@
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
 #include "tensorflow/compiler/xla/xla_client/sys_util.h"
 #include "tensorflow/compiler/xla/xla_client/unique.h"
+#include "torch/csrc/lazy/core/hash.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/ir.h"
 #include "torch_xla/csrc/ir_dump_util.h"
@@ -55,7 +56,7 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span<const XLATensor> tensors,
                                            GraphFormat format) {
   std::vector<const ir::Node*> root_nodes;
   std::vector<ir::Value> root_values;
-  std::vector<xla::hash_t> root_hashes;
+  std::vector<torch::lazy::hash_t> root_hashes;
   xla::util::Unique<Device> unique_device;
   if (indices != nullptr) {
     for (auto index : *indices) {
@@ -91,7 +92,7 @@ std::string DebugUtil::GetTensorsGraphInfo(absl::Span<const XLATensor> tensors,
     if (i > 0) {
       ss << ", ";
     }
-    ss << xla::util::HexHash(root_hashes[i]);
+    ss << torch::lazy::HashToString(root_hashes[i]);
   }
   ss << ")\n";
 
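As a usage sketch, the hash-printing loop above reduces to something like the following, assuming torch::lazy::HashToString renders a hash_t as a string (FormatRootHashes is a hypothetical standalone helper, not part of this diff):

// Sketch only: printing hashes with torch::lazy::HashToString, the
// replacement for xla::util::HexHash.
#include <sstream>
#include <string>
#include <vector>

#include "torch/csrc/lazy/core/hash.h"

std::string FormatRootHashes(const std::vector<torch::lazy::hash_t>& hashes) {
  std::stringstream ss;
  ss << "Root Hashes: (";
  for (size_t i = 0; i < hashes.size(); ++i) {
    if (i > 0) {
      ss << ", ";
    }
    ss << torch::lazy::HashToString(hashes[i]);
  }
  ss << ")\n";
  return ss.str();
}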
5 changes: 3 additions & 2 deletions torch_xla/csrc/device.h
@@ -4,6 +4,7 @@
 #include <string>
 
 #include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch/csrc/lazy/core/hash.h"
 
 namespace torch_xla {
 
@@ -36,8 +37,8 @@ struct Device {
   }
 
   size_t hash() const {
-    return xla::util::StdHashCombine(xla::util::GetEnumValue(hw_type),
-                                     ordinal + 1);
+    return torch::lazy::StdHashCombine(xla::util::GetEnumValue(hw_type),
+                                       ordinal + 1);
   }
 
   DeviceType hw_type = DeviceType::CPU;
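Unlike the 128-bit hash_t paths, Device::hash() stays a size_t; a minimal sketch of the combiner, assuming torch::lazy::StdHashCombine takes and returns plain integer values (DeviceHash is a hypothetical stand-in for the member function):

// Sketch only: std-style size_t hashing via torch::lazy::StdHashCombine.
#include <cstddef>

#include "torch/csrc/lazy/core/hash.h"

size_t DeviceHash(int hw_type_value, int ordinal) {
  // ordinal + 1 mirrors the code above, keeping ordinal 0 distinct from the
  // enum-only hash.
  return torch::lazy::StdHashCombine(hw_type_value, ordinal + 1);
}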
38 changes: 19 additions & 19 deletions torch_xla/csrc/ir.cpp
@@ -7,15 +7,15 @@
 #include "tensorflow/compiler/xla/xla_client/cache.h"
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
 #include "tensorflow/compiler/xla/xla_client/sys_util.h"
-#include "tensorflow/compiler/xla/xla_client/util.h"
+#include "torch/csrc/lazy/core/hash.h"
 #include "torch_xla/csrc/lowering_context.h"
 
 namespace torch_xla {
 namespace ir {
 namespace {
 
 using ShapeCache =
-    xla::util::Cache<xla::hash_t, xla::Shape, xla::util::HashReducer>;
+    xla::util::Cache<torch::lazy::hash_t, xla::Shape, torch::lazy::HashReducer>;
 
 struct ScopeEntry {
   std::string name;
@@ -101,16 +101,16 @@ std::string Use::ToString() const {
 }
 
 size_t Output::Hasher::operator()(const Output& output) const {
-  return xla::util::StdHashCombine(
+  return torch::lazy::StdHashCombine(
       reinterpret_cast<std::ptrdiff_t>(output.node), output.index);
 }
 
 const xla::Shape& Output::shape() const { return node->shape(index); }
 
 const xla::Shape& Output::node_shape() const { return node->shape(); }
 
-xla::hash_t Output::hash() const {
-  return xla::util::HashCombine(node->hash(), index);
+torch::lazy::hash_t Output::hash() const {
+  return torch::lazy::HashCombine(node->hash(), index);
 }
 
 std::string Output::ToString() const {
@@ -123,44 +123,44 @@ const xla::Shape& Value::shape() const { return node->shape(index); }
 
 const xla::Shape& Value::node_shape() const { return node->shape(); }
 
-xla::hash_t Value::hash() const {
-  return xla::util::HashCombine(node->hash(), index);
+torch::lazy::hash_t Value::hash() const {
+  return torch::lazy::HashCombine(node->hash(), index);
 }
 
 OpKind OpKind::Get(const std::string& name) {
   return OpKind(c10::Symbol::fromQualString(name));
 }
 
-xla::hash_t OpKind::hash() const {
-  return xla::util::StringHash(op.toQualString());
+torch::lazy::hash_t OpKind::hash() const {
+  return torch::lazy::StringHash(op.toQualString());
 }
 
 Node::Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs,
-           xla::hash_t hash_seed)
+           torch::lazy::hash_t hash_seed)
     : op_(std::move(op)),
       num_outputs_(num_outputs),
      shape_(std::move(shape)),
-      node_hash_(xla::util::HashCombine(op_.hash(), hash_seed)),
+      node_hash_(torch::lazy::HashCombine(op_.hash(), hash_seed)),
       hash_(node_hash_) {
   metadata_.scope = GetCurrentScope();
   metadata_.frame_info = GetFrameInfo();
   for (auto& operand : operands) {
     AddOperand(operand.node, operand.index);
-    hash_ = xla::util::HashCombine(hash_, operand.hash());
+    hash_ = torch::lazy::HashCombine(hash_, operand.hash());
   }
 }
 
 Node::Node(OpKind op, OpList operands,
            const std::function<xla::Shape()>& shape_fn, size_t num_outputs,
-           xla::hash_t hash_seed)
+           torch::lazy::hash_t hash_seed)
     : Node(std::move(op), operands, xla::Shape(), num_outputs, hash_seed) {
   // Forward the constructor to the one above (with empty shape), so we have the
   // full hash information, then fetch/compute the real shape.
   shape_ = GetOpShape(shape_fn);
 }
 
 Node::Node(OpKind op, xla::Shape shape, size_t num_outputs,
-           xla::hash_t hash_seed)
+           torch::lazy::hash_t hash_seed)
     : op_(std::move(op)),
       num_outputs_(num_outputs),
       shape_(std::move(shape)),
@@ -247,11 +247,11 @@ XlaOpVector Node::Lower(LoweringContext* loctx) const {
   XLA_ERROR() << "Lowering not implemented for node: " << *this;
 }
 
-xla::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape,
-                            xla::hash_t hash_seed) {
-  xla::hash_t h =
-      xla::util::HashCombine(op.hash(), xla::util::Hash(shape.ToString()));
-  return xla::util::HashCombine(h, hash_seed);
+torch::lazy::hash_t Node::GetOpHash(OpKind op, const xla::Shape& shape,
+                                    torch::lazy::hash_t hash_seed) {
+  torch::lazy::hash_t h =
+      torch::lazy::HashCombine(op.hash(), torch::lazy::Hash(shape.ToString()));
+  return torch::lazy::HashCombine(h, hash_seed);
 }
 
 xla::Shape Node::GetOpShape(const std::function<xla::Shape()>& shape_fn) const {
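A minimal sketch of the hashing scheme these constructors implement, assuming torch::lazy::HashCombine mixes two hash_t values (NodeHash is a hypothetical free function, not part of this diff):

// Sketch only: how a node's graph hash is assembled from the op hash, the
// seed, and the operand hashes, mirroring the Node constructor above.
#include <vector>

#include "torch/csrc/lazy/core/hash.h"

torch::lazy::hash_t NodeHash(
    const torch::lazy::hash_t& op_hash, const torch::lazy::hash_t& hash_seed,
    const std::vector<torch::lazy::hash_t>& operand_hashes) {
  // node_hash_: identity of the operation itself.
  torch::lazy::hash_t h = torch::lazy::HashCombine(op_hash, hash_seed);
  // hash_: identity of the whole graph rooted at this node.
  for (const torch::lazy::hash_t& operand_hash : operand_hashes) {
    h = torch::lazy::HashCombine(h, operand_hash);
  }
  return h;
}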
27 changes: 15 additions & 12 deletions torch_xla/csrc/ir.h
@@ -16,6 +16,7 @@
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_client/types.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "torch/csrc/lazy/core/hash.h"
 #include "torch_xla/csrc/python_util.h"
 
 namespace torch_xla {
@@ -83,7 +84,7 @@ struct Output {
   const xla::Shape& shape() const;
   const xla::Shape& node_shape() const;
 
-  xla::hash_t hash() const;
+  torch::lazy::hash_t hash() const;
 
   bool operator==(const Output& rhs) const {
     return node == rhs.node && index == rhs.index;
@@ -120,7 +121,7 @@ struct Value {
   const xla::Shape& shape() const;
   const xla::Shape& node_shape() const;
 
-  xla::hash_t hash() const;
+  torch::lazy::hash_t hash() const;
 
   operator bool() const { return node != nullptr; }
 
@@ -143,7 +144,7 @@ struct OpKind {
     return c10::unique_t(op) < c10::unique_t(rhs.op);
   }
 
-  xla::hash_t hash() const;
+  torch::lazy::hash_t hash() const;
 
   std::string ToString() const { return op.toQualString(); }
 
@@ -174,15 +175,17 @@ class Node {
   // for the operation. The num_outputs tells how many outputs a given operation
   // generates.
   Node(OpKind op, OpList operands, xla::Shape shape, size_t num_outputs = 1,
-       xla::hash_t hash_seed = 0x5a2d296e9);
+       torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9);
 
   // Same as the constructor above, but the shape is generated by a function,
   // only if needed (shape cache miss).
   Node(OpKind op, OpList operands, const std::function<xla::Shape()>& shape_fn,
-       size_t num_outputs = 1, xla::hash_t hash_seed = 0x5a2d296e9);
+       size_t num_outputs = 1,
+       torch::lazy::hash_t hash_seed = (uint32_t)0x5a2d296e9);
 
   // Contructor used to create leaf nodes.
-  Node(OpKind op, xla::Shape shape, size_t num_outputs, xla::hash_t hash_seed);
+  Node(OpKind op, xla::Shape shape, size_t num_outputs,
+       torch::lazy::hash_t hash_seed);
 
   virtual ~Node();
 
@@ -204,9 +207,9 @@ class Node {
 
   const std::set<Use>& uses() const { return uses_; }
 
-  xla::hash_t node_hash() const { return node_hash_; }
+  torch::lazy::hash_t node_hash() const { return node_hash_; }
 
-  xla::hash_t hash() const { return hash_; }
+  torch::lazy::hash_t hash() const { return hash_; }
 
   const MetaData& metadata() const { return metadata_; }
 
@@ -243,8 +246,8 @@ class Node {
 
   xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
 
-  static xla::hash_t GetOpHash(OpKind op, const xla::Shape& shape,
-                               xla::hash_t hash_seed);
+  static torch::lazy::hash_t GetOpHash(OpKind op, const xla::Shape& shape,
+                                       torch::lazy::hash_t hash_seed);
 
   static std::vector<SourceLocation> GetFrameInfo();
 
@@ -260,9 +263,9 @@ class Node {
   // We use a set for uses, as we want deterministic use sequencing.
   std::set<Use> uses_;
   // The hash value of this node.
-  xla::hash_t node_hash_ = 0;
+  torch::lazy::hash_t node_hash_ = 0;
   // The hash value of the graph rooted at this node.
-  xla::hash_t hash_ = 0;
+  torch::lazy::hash_t hash_ = 0;
   // The IR specific metadata attached to the IR node.
   MetaData metadata_;
   // The IR framework user can attach a user defined metadata object deriving
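One wrinkle in the new default seeds, as a hedged sketch: this assumes torch::lazy::hash_t is a 128-bit integer type (c10::uint128 in torch/csrc/lazy/core/hash.h), and DefaultSeed is a hypothetical helper, not part of this diff:

// Sketch only. Assumes torch::lazy::hash_t is a 128-bit integer type
// (c10::uint128 in torch/csrc/lazy/core/hash.h).
#include <cstdint>

#include "torch/csrc/lazy/core/hash.h"

torch::lazy::hash_t DefaultSeed() {
  // The cast keeps only the low 32 bits of the literal 0x5a2d296e9
  // (0xa2d296e9), which then widen to the 128-bit hash_t.
  return torch::lazy::hash_t((uint32_t)0x5a2d296e9);
}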
37 changes: 20 additions & 17 deletions torch_xla/csrc/op_by_op_executor.cpp
@@ -8,13 +8,14 @@
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
 #include "tensorflow/compiler/xla/xla_client/metrics.h"
 #include "tensorflow/compiler/xla/xla_client/sys_util.h"
-#include "tensorflow/compiler/xla/xla_client/util.h"
 #include "tensorflow/compiler/xla/xla_client/xla_util.h"
+#include "torch/csrc/lazy/core/hash.h"
 #include "torch_xla/csrc/device.h"
 #include "torch_xla/csrc/ir_util.h"
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/ops/device_data.h"
 #include "torch_xla/csrc/tensor_util.h"
+#include "torch_xla/csrc/torch_util.h"
 
 namespace torch_xla {
 namespace {
@@ -41,17 +42,17 @@ const xla::Shape& GetParameterShape(const ir::Output& operand,
              : xla::ShapeUtil::GetTupleElementShape(input_shape, operand.index);
 }
 
-xla::hash_t ComputeNodeKey(const ir::Node* node,
-                           absl::Span<const xla::Shape* const> input_shapes,
-                           const xla::hash_t& seed) {
-  xla::hash_t key = seed;
+torch::lazy::hash_t ComputeNodeKey(
+    const ir::Node* node, absl::Span<const xla::Shape* const> input_shapes,
+    const torch::lazy::hash_t& seed) {
+  torch::lazy::hash_t key = seed;
   const auto& operands = node->operands();
   for (size_t i = 0; i < operands.size(); ++i) {
-    key = xla::util::HashCombine(key, xla::util::ShapeHash(GetParameterShape(
-                                          operands[i], *input_shapes[i])));
+    key = torch::lazy::HashCombine(key, torch::lazy::Hash(GetParameterShape(
+                                            operands[i], *input_shapes[i])));
   }
-  key = xla::util::HashCombine(key, xla::util::ShapeHash(node->shape()));
-  return xla::util::HashCombine(key, node->node_hash());
+  key = torch::lazy::HashCombine(key, torch::lazy::Hash(node->shape()));
+  return torch::lazy::HashCombine(key, node->node_hash());
 }
 
 xla::XlaComputation BuildNodeComputation(
@@ -71,9 +72,9 @@ xla::XlaComputation BuildNodeComputation(
   return ConsumeValue(loctx.Build());
 }
 
-xla::hash_t GetNodesKeySeed(const std::string& device,
-                            absl::Span<const std::string> devices) {
-  return xla::util::MHash(device, devices);
+torch::lazy::hash_t GetNodesKeySeed(const std::string& device,
+                                    absl::Span<const std::string> devices) {
+  return torch::lazy::MHash(device, torch::lazy::Hash(devices));
 }
 
 }  // namespace
@@ -102,12 +103,14 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
 
   auto compilation_devices =
       xla::ComputationClient::Get()->GetCompilationDevices(device, devices);
-  xla::hash_t nodes_key_seed = GetNodesKeySeed(device, compilation_devices);
+  torch::lazy::hash_t nodes_key_seed =
+      GetNodesKeySeed(device, compilation_devices);
   Device exec_device(device);
-  std::vector<xla::hash_t> cache_keys;
-  std::unordered_map<xla::hash_t, std::vector<size_t>, xla::util::HashReducer>
+  std::vector<torch::lazy::hash_t> cache_keys;
+  std::unordered_map<torch::lazy::hash_t, std::vector<size_t>,
+                     torch::lazy::HashReducer>
      compile_indices;
-  std::unordered_map<xla::hash_t, size_t, xla::util::HashReducer>
+  std::unordered_map<torch::lazy::hash_t, size_t, torch::lazy::HashReducer>
      cache_keys_instance;
  std::list<xla::Shape> compile_shapes;
  std::vector<bool> device_data_ops(post_order.size());
@@ -133,7 +136,7 @@ std::vector<xla::ComputationClient::ExecuteChainedOp> OpByOpExecutor::BuildOps(
       op_input_shapes.push_back(ops_shapes[op_index]);
     }
 
-    xla::hash_t cache_key =
+    torch::lazy::hash_t cache_key =
         ComputeNodeKey(node, op_input_shapes, nodes_key_seed);
     cxop.computation = compile_cache_.Get(cache_key);
     if (cxop.computation == nullptr) {
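A minimal sketch of the map pattern used for compile_indices and cache_keys_instance above, assuming torch::lazy::HashReducer folds the wide hash_t into the size_t that std::unordered_map needs (CacheKeyIndex and RecordKey are hypothetical names, not part of this diff):

// Sketch only: keying an unordered_map on torch::lazy::hash_t.
#include <cstddef>
#include <unordered_map>
#include <vector>

#include "torch/csrc/lazy/core/hash.h"

using CacheKeyIndex =
    std::unordered_map<torch::lazy::hash_t, std::vector<size_t>,
                       torch::lazy::HashReducer>;

void RecordKey(CacheKeyIndex& index, const torch::lazy::hash_t& key,
               size_t position) {
  // HashReducer only picks the bucket; full-width hash_t equality still
  // distinguishes keys that reduce to the same size_t.
  index[key].push_back(position);
}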
4 changes: 2 additions & 2 deletions torch_xla/csrc/op_by_op_executor.h
@@ -38,8 +38,8 @@ class OpByOpExecutor {
 
  private:
   using CompileCache =
-      xla::util::Cache<xla::hash_t, xla::ComputationClient::Computation,
-                       xla::util::HashReducer>;
+      xla::util::Cache<torch::lazy::hash_t, xla::ComputationClient::Computation,
+                       torch::lazy::HashReducer>;
 
   explicit OpByOpExecutor(size_t compile_cache_size);
 
3 changes: 1 addition & 2 deletions torch_xla/csrc/ops/adaptive_avg_pool2d.cpp
@@ -1,7 +1,6 @@
 #include "torch_xla/csrc/ops/adaptive_avg_pool2d.h"
 
 #include "tensorflow/compiler/xla/xla_client/debug_macros.h"
-#include "tensorflow/compiler/xla/xla_client/util.h"
 #include "torch_xla/csrc/lowering_context.h"
 #include "torch_xla/csrc/ops/infer_output_shape.h"
 #include "torch_xla/csrc/pooling.h"
@@ -27,7 +26,7 @@ AdaptiveAvgPool2d::AdaptiveAvgPool2d(const Value& input,
                                      std::vector<xla::int64> output_size)
     : Node(ir::OpKind(at::aten::adaptive_avg_pool2d), {input},
            [&]() { return NodeOutputShape(input, output_size); },
-           /*num_outputs=*/1, xla::util::MHash(output_size)),
+           /*num_outputs=*/1, torch::lazy::MHash(output_size)),
       output_size_(std::move(output_size)) {}
 
 NodePtr AdaptiveAvgPool2d::Clone(OpList operands) const {