Add sharding hash to IR nodes #4286
Changes from all commits
@@ -113,15 +113,21 @@ class XlaNode : public torch::lazy::Node {

   torch::lazy::hash_t hash() const override { return dag_hash_; }

   torch::lazy::hash_t shapeHash() const override { return dag_hash_; }

+  torch::lazy::hash_t shardingHash() const { return sharding_hash_; }
+
   // The node's outputs get assigned the same HLO sharding
   // TODO: test multi-output example.
   const std::shared_ptr<xla::OpSharding> GetSharding() const {
     return output_sharding_;
   }
-  void SetSharding(const xla::OpSharding& sharding) {
-    output_sharding_ = std::make_shared<xla::OpSharding>(sharding);
-  }
-  void ClearSharding() { output_sharding_ = nullptr; }
+  void SetSharding(const xla::OpSharding& sharding);
+
+  void ClearSharding() {
+    output_sharding_ = nullptr;
+    sharding_hash_ = 0;
+  }

  private:
   xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;

Review comment (on the SetSharding/ClearSharding change): I felt like we should not provide …

Reply: In a subsequent PR I'll be removing them so that this is all captured in the constructor.
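The hunk above only turns SetSharding into an out-of-line declaration; its body is not part of this diff. As a minimal sketch of what the corresponding definition in the .cpp could look like — assuming it both stores the annotation and refreshes the cached hash via the CreateShardingHash helper declared in the next hunk, and assuming node_hash_ as the seed — consider:

```cpp
// Hedged sketch, not the PR's actual .cpp body (which this diff does not show).
// Assumes the XlaNode class declaration from the hunk above. The setter stores
// the annotation and refreshes the cached sharding hash; using node_hash_ as
// the seed is an assumption for illustration.
void XlaNode::SetSharding(const xla::OpSharding& sharding) {
  output_sharding_ = std::make_shared<xla::OpSharding>(sharding);
  sharding_hash_ = CreateShardingHash(output_sharding_, node_hash_);
}
```

This pairs with the new ClearSharding, which resets sharding_hash_ to 0 so that a node whose annotation is removed hashes the same as one that was never annotated.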
@@ -132,9 +138,13 @@ class XlaNode : public torch::lazy::Node {

   static std::vector<torch::lazy::SourceLocation> GetFrameInfo();

+  static torch::lazy::hash_t CreateShardingHash(
+      std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed);
+
   xla::Shape xla_shape_;
   torch::lazy::hash_t node_hash_ = 0;
   torch::lazy::hash_t dag_hash_;
+  torch::lazy::hash_t sharding_hash_ = 0;

   // Experimental sharding annotation attached to the IR node.
   // TODO(yeounoh): make sure that view update doesn't reset this.

Review comment (on sharding_hash_): How does sharding_hash_ get used? Is it combined with node_hash_ to represent the final node hash?

Reply: In the subsequent PR (#4287), we combine it with the dag hash when …

Review comment: I think we should just combine it with …

Reply: In #4288 I move that logic into the constructor; this was just meant to be an intermediate state.
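To make the review thread concrete, here is a hedged sketch of the two pieces under discussion: how CreateShardingHash could fold the xla::OpSharding proto into a hash given a seed, and how that hash could later be mixed into the DAG hash (the direction mentioned for #4287/#4288). The serialize-then-hash approach, the include paths, and the use of torch::lazy::DataHash / torch::lazy::HashCombine are assumptions for illustration, not necessarily what the PR implements:

```cpp
#include <torch/csrc/lazy/core/hash.h>

#include <memory>
#include <string>

#include "xla/xla_data.pb.h"  // xla::OpSharding; exact include path is an assumption

// Sketch: hash the serialized OpSharding proto and mix it with the seed.
// The real CreateShardingHash may instead hash individual proto fields
// (type, tile assignment dimensions, tuple shardings, ...).
torch::lazy::hash_t CreateShardingHashSketch(
    std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed) {
  if (sharding == nullptr) {
    return hash_seed;
  }
  std::string serialized = sharding->SerializeAsString();
  torch::lazy::hash_t proto_hash =
      torch::lazy::DataHash(serialized.data(), serialized.size());
  return torch::lazy::HashCombine(hash_seed, proto_hash);
}

// Sketch of the combination discussed in the comments: the hash a node reports
// would mix the DAG hash with the sharding hash, so that changing only the
// sharding annotation still produces a different overall hash.
torch::lazy::hash_t CombinedNodeHashSketch(torch::lazy::hash_t dag_hash,
                                           torch::lazy::hash_t sharding_hash) {
  return torch::lazy::HashCombine(dag_hash, sharding_hash);
}
```

Per the thread above, the follow-up PRs move this combination into the XlaNode constructor so callers never have to refresh the hash by hand; this sketch only illustrates the intermediate state reviewed here.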