45 changes: 45 additions & 0 deletions torch_xla/csrc/ir.cpp
@@ -150,6 +150,13 @@ torch::lazy::hash_t XlaNode::GetOpHash(torch::lazy::OpKind op,
return torch::lazy::HashCombine(h, hash_seed);
}

void XlaNode::SetSharding(const xla::OpSharding& sharding) {
output_sharding_ = std::make_shared<xla::OpSharding>(sharding);
// TODO(steventk) Once we move this into the constructor, we can use the
// hash seed. For now, we'll use node_hash_ as a seed.
sharding_hash_ = CreateShardingHash(output_sharding_, node_hash_);
}

xla::Shape XlaNode::GetOpShape(
const std::function<xla::Shape()>& shape_fn) const {
ShapeCache* shape_cache = GetShapeCache();
@@ -165,4 +172,42 @@ const xla::Shape& GetXlaShape(const torch::lazy::Value& value) {
return casted->xla_shape(value.index);
}

// The sharding hash is based only on the relevant fields of the
// xla::OpSharding object; the layout field is irrelevant and is skipped.
torch::lazy::hash_t XlaNode::CreateShardingHash(
std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed) {
torch::lazy::hash_t sharding_hash = hash_seed;
for (const auto& tile_assignment_dimension :
sharding->tile_assignment_dimensions()) {
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)tile_assignment_dimension);
}
for (const auto& tile_assignment_device :
sharding->tile_assignment_devices()) {
sharding_hash = torch::lazy::HashCombine(sharding_hash,
(uint32_t)tile_assignment_device);
}
for (const auto& last_tile_dim : sharding->last_tile_dims()) {
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)last_tile_dim);
}
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)sharding->type());
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)sharding->replicate_on_last_tile_dim());

xla::ShapeProto shape_proto = sharding->tile_shape();
sharding_hash = torch::lazy::HashCombine(
sharding_hash, (uint32_t)shape_proto.element_type());
for (const auto& dim : shape_proto.dimensions()) {
sharding_hash = torch::lazy::HashCombine(sharding_hash, (uint32_t)dim);
}
for (const auto& is_dyn_dim : shape_proto.is_dynamic_dimension()) {
sharding_hash =
torch::lazy::HashCombine(sharding_hash, (uint32_t)is_dyn_dim);
}

return sharding_hash;
}

} // namespace torch_xla
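
For context, a rough usage sketch (not part of the diff) of how the new SetSharding / shardingHash hooks might be exercised. The MakeTiledSharding helper is hypothetical and only shows one typical way to populate an xla::OpSharding proto; `node` is assumed to point at an existing XlaNode.

// Hypothetical usage sketch -- not part of this PR.
#include <cstdint>

#include "torch_xla/csrc/ir.h"
#include "xla/xla_data.pb.h"  // xla::OpSharding; include path may differ by build.

// Builds a simple 1-D tiled sharding across `num_devices` devices
// (hypothetical helper, for illustration only).
xla::OpSharding MakeTiledSharding(int64_t num_devices) {
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::OTHER);
  sharding.add_tile_assignment_dimensions(num_devices);
  for (int64_t d = 0; d < num_devices; ++d) {
    sharding.add_tile_assignment_devices(d);
  }
  return sharding;
}

void AnnotateAndHash(torch_xla::XlaNode* node) {
  // Attach the sharding; per this PR, SetSharding also computes
  // sharding_hash_ from the OpSharding fields, seeded with node_hash_.
  node->SetSharding(MakeTiledSharding(/*num_devices=*/4));
  // Nodes with different shardings now expose different sharding hashes.
  torch::lazy::hash_t h = node->shardingHash();
  (void)h;
}
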
16 changes: 13 additions & 3 deletions torch_xla/csrc/ir.h
@@ -113,15 +113,21 @@ class XlaNode : public torch::lazy::Node {
torch::lazy::hash_t hash() const override { return dag_hash_; }

torch::lazy::hash_t shapeHash() const override { return dag_hash_; }

torch::lazy::hash_t shardingHash() const { return sharding_hash_; }

// The node's outputs get assigned the same HLO sharding
// TODO: test multi-output example.
const std::shared_ptr<xla::OpSharding> GetSharding() const {
return output_sharding_;
}
void SetSharding(const xla::OpSharding& sharding) {
output_sharding_ = std::make_shared<xla::OpSharding>(sharding);

void SetSharding(const xla::OpSharding& sharding);

void ClearSharding() {
output_sharding_ = nullptr;
sharding_hash_ = 0;
Collaborator: I felt like we should not provide ClearSharding and SetSharding; we need to create a new IR node if we want to modify the sharding.

@steventk-g (Collaborator Author), Dec 8, 2022: In a subsequent PR I'll be removing them so that this is all captured in the constructor.

}
void ClearSharding() { output_sharding_ = nullptr; }

private:
xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
@@ -132,9 +138,13 @@

static std::vector<torch::lazy::SourceLocation> GetFrameInfo();

static torch::lazy::hash_t CreateShardingHash(
std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed);

xla::Shape xla_shape_;
torch::lazy::hash_t node_hash_ = 0;
torch::lazy::hash_t dag_hash_;
torch::lazy::hash_t sharding_hash_ = 0;
Contributor: How does sharding_hash_ get used? Is it combined with node_hash_ to represent the final node hash?

Collaborator Author: In the subsequent PR (#4287), we combine it with the dag hash when hash() is called.

Collaborator: I think we should just combine it with dag_hash_ during construction rather than maintaining it as a separate state.

@steventk-g (Collaborator Author), Dec 8, 2022: In #4288 I move that logic into the constructor; this was just meant to be an intermediate state.


// Experimental sharding annotation attached to the IR node.
// TODO(yeounoh): make sure that view update doesn't reset this.
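
To make the review discussion above concrete, a minimal sketch of how sharding_hash_ might be folded into the node hash when hash() is called, assuming the approach described for #4287; the actual change lives in that PR and may differ.

// Sketch only; #4287 is the authoritative change.
torch::lazy::hash_t XlaNode::hash() const {
  // A node carrying a sharding annotation should hash differently from the
  // same node without one, so the two never share a cached computation.
  return sharding_hash_ != 0
             ? torch::lazy::HashCombine(dag_hash_, sharding_hash_)
             : dag_hash_;
}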