diff --git a/torch_xla/csrc/ir.cpp b/torch_xla/csrc/ir.cpp
index 91a42097b368..4f6cc07d3dde 100644
--- a/torch_xla/csrc/ir.cpp
+++ b/torch_xla/csrc/ir.cpp
@@ -150,6 +150,13 @@ torch::lazy::hash_t XlaNode::GetOpHash(torch::lazy::OpKind op,
   return torch::lazy::HashCombine(h, hash_seed);
 }
 
+void XlaNode::SetSharding(const xla::OpSharding& sharding) {
+  output_sharding_ = std::make_shared<xla::OpSharding>(sharding);
+  // TODO(steventk) Once we move this into the constructor, we can use the
+  // hash seed. For now, we'll use node_hash_ as a seed.
+  sharding_hash_ = CreateShardingHash(output_sharding_, node_hash_);
+}
+
 xla::Shape XlaNode::GetOpShape(
     const std::function<xla::Shape()>& shape_fn) const {
   ShapeCache* shape_cache = GetShapeCache();
@@ -165,4 +172,42 @@ const xla::Shape& GetXlaShape(const torch::lazy::Value& value) {
   return casted->xla_shape(value.index);
 }
 
+// The sharding hash is only based on relevant fields from the xla::OpSharding
+// object. We skip the field that's irrelevant, which is the layout.
+torch::lazy::hash_t XlaNode::CreateShardingHash(
+    std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed) {
+  torch::lazy::hash_t sharding_hash = hash_seed;
+  for (const auto& tile_assignment_dimension :
+       sharding->tile_assignment_dimensions()) {
+    sharding_hash = torch::lazy::HashCombine(
+        sharding_hash, (uint32_t)tile_assignment_dimension);
+  }
+  for (const auto& tile_assignment_device :
+       sharding->tile_assignment_devices()) {
+    sharding_hash = torch::lazy::HashCombine(sharding_hash,
+                                             (uint32_t)tile_assignment_device);
+  }
+  for (const auto& last_tile_dim : sharding->last_tile_dims()) {
+    sharding_hash =
+        torch::lazy::HashCombine(sharding_hash, (uint32_t)last_tile_dim);
+  }
+  sharding_hash =
+      torch::lazy::HashCombine(sharding_hash, (uint32_t)sharding->type());
+  sharding_hash = torch::lazy::HashCombine(
+      sharding_hash, (uint32_t)sharding->replicate_on_last_tile_dim());
+
+  xla::ShapeProto shape_proto = sharding->tile_shape();
+  sharding_hash = torch::lazy::HashCombine(
+      sharding_hash, (uint32_t)shape_proto.element_type());
+  for (const auto& dim : shape_proto.dimensions()) {
+    sharding_hash = torch::lazy::HashCombine(sharding_hash, (uint32_t)dim);
+  }
+  for (const auto& is_dyn_dim : shape_proto.is_dynamic_dimension()) {
+    sharding_hash =
+        torch::lazy::HashCombine(sharding_hash, (uint32_t)is_dyn_dim);
+  }
+
+  return sharding_hash;
+}
+
 }  // namespace torch_xla
diff --git a/torch_xla/csrc/ir.h b/torch_xla/csrc/ir.h
index 8805af28a1f3..1c9d510ba4e6 100644
--- a/torch_xla/csrc/ir.h
+++ b/torch_xla/csrc/ir.h
@@ -113,15 +113,21 @@ class XlaNode : public torch::lazy::Node {
   torch::lazy::hash_t hash() const override { return dag_hash_; }
 
   torch::lazy::hash_t shapeHash() const override { return dag_hash_; }
+
+  torch::lazy::hash_t shardingHash() const { return sharding_hash_; }
+
   // The node's outputs get assigned the same HLO sharding
   // TODO: test multi-output example.
   const std::shared_ptr<xla::OpSharding> GetSharding() const {
     return output_sharding_;
   }
-  void SetSharding(const xla::OpSharding& sharding) {
-    output_sharding_ = std::make_shared<xla::OpSharding>(sharding);
+
+  void SetSharding(const xla::OpSharding& sharding);
+
+  void ClearSharding() {
+    output_sharding_ = nullptr;
+    sharding_hash_ = 0;
   }
-  void ClearSharding() { output_sharding_ = nullptr; }
 
  private:
   xla::Shape GetOpShape(const std::function<xla::Shape()>& shape_fn) const;
@@ -132,9 +138,13 @@ class XlaNode : public torch::lazy::Node {
   static std::vector<torch::lazy::SourceLocation> GetFrameInfo();
 
+  static torch::lazy::hash_t CreateShardingHash(
+      std::shared_ptr<xla::OpSharding> sharding, torch::lazy::hash_t hash_seed);
+
   xla::Shape xla_shape_;
   torch::lazy::hash_t node_hash_ = 0;
   torch::lazy::hash_t dag_hash_;
+  torch::lazy::hash_t sharding_hash_ = 0;
 
   // Experimental sharding annotation attached to the IR node.
   // TODO(yeounoh): make sure that view update doesn't reset this.
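
Not part of the patch: a minimal usage sketch of the new interface. It assumes an already-constructed XlaNode* (construction details omitted); the helper name AnnotateAndHash and the particular 2-way tiled OpSharding are illustrative only. It shows that SetSharding() now derives sharding_hash_ from the relevant OpSharding fields (seeded with node_hash_, per the TODO above), shardingHash() exposes it, and ClearSharding() resets it to 0 along with the annotation.

#include <cassert>

#include "torch_xla/csrc/ir.h"

void AnnotateAndHash(torch_xla::XlaNode* node) {
  // Build a simple 2-way tiled sharding: dimension 0 split across devices 0 and 1.
  xla::OpSharding sharding;
  sharding.set_type(xla::OpSharding::OTHER);
  sharding.add_tile_assignment_dimensions(2);
  sharding.add_tile_assignment_dimensions(1);
  sharding.add_tile_assignment_devices(0);
  sharding.add_tile_assignment_devices(1);

  // SetSharding stores the annotation and computes the sharding hash.
  node->SetSharding(sharding);
  torch::lazy::hash_t sharding_hash = node->shardingHash();
  (void)sharding_hash;  // e.g. fold into a graph/cache key alongside hash()

  // ClearSharding drops the annotation and zeroes the hash.
  node->ClearSharding();
  assert(node->GetSharding() == nullptr);
  assert(node->shardingHash() == 0);
}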