Add utility functions for distributed checkpointing #5128
Conversation
cc @yashs97
-  // Clamp the end of the slice to the tensor shape to accurately reflect
+  // Clamp the slice bounds to the tensor shape to accurately reflect
   // the shard size without padding.
   int start = std::min(n_j * shard_shape[j], tensor_shape[j]);
Did you catch a bug, or is this more of a safeguard?
If n_j * shard_shape[j] is greater than tensor_shape[j], then (n_j + 1) * shard_shape[j] is certainly larger than tensor_shape[j] as well. In that scenario, start and end are both equal to tensor_shape[j], and that slice seems meaningless.
I would call this a latent bug, but it wasn't breaking anything because torch indexing handles negative-length slices as though they were empty. It just breaks the expectation that stop - start reflects the size of the unpadded shard, which we rely on in distributed checkpointing.
You're right - these index slices will end up empty, but this is the desired outcome when the shard consists entirely of padding.
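To make the clamping behavior concrete, here is a small numeric sketch (hypothetical shapes and helper, not code from this PR) showing how clamping both bounds keeps stop - start equal to the unpadded shard size and yields an empty slice for a shard that is entirely padding:

# Hypothetical sketch of the clamping discussed above; the shapes and helper
# are illustrative only, not taken from this PR.
tensor_dim = 5  # size of the tensor along dimension j
shard_dim = 2   # per-shard size along dimension j (4 shards, so 3 elements of padding)

def shard_slice(n_j):
    # Clamp both bounds to the tensor size so stop - start is the unpadded shard size.
    start = min(n_j * shard_dim, tensor_dim)
    stop = min((n_j + 1) * shard_dim, tensor_dim)
    return slice(start, stop)

print(shard_slice(0))  # slice(0, 2, None): full shard
print(shard_slice(2))  # slice(4, 5, None): partially padded shard, unpadded size 1
print(shard_slice(3))  # slice(5, 5, None): shard consisting entirely of padding, empty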
alanwaketan left a comment:
LGTM.
    return ShardingType::TUPLE;
  case xla::OpSharding::OTHER:
    // OTHER sharding can indicate either PARTIAL or TILED sharding.
    return sharding.replicate_on_last_tile_dim() ? ShardingType::PARTIAL
This seems pretty hacky, but I guess we don't have another way around it?
My understanding is that we distinguish partial replication as a separate sharding type, whereas XLA treats both partial and tiled sharding as the same OTHER type. @yeounoh could you confirm?
Yes, this is actually the correct way; the compiler treats TILED and PARTIAL the same, as the OTHER type. The difference between the two is how the tile shards are assigned to different devices.
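For readers unfamiliar with the proto, a minimal sketch (illustrative values only, not this PR's code) of how the same OTHER sharding is interpreted as TILED or PARTIAL:

# Illustrative sketch only: an OTHER OpSharding is reported as PARTIAL when its
# last tile dimension is reserved for replication, and as TILED otherwise.
def classify_other(tile_assignment_dims, replicate_on_last_tile_dim):
    # e.g. dims [2, 2] over 4 devices                       -> "TILED"
    #      dims [2, 2] with replicate_on_last_tile_dim=True -> 2-way tiling,
    #      2-way replication                                 -> "PARTIAL"
    return "PARTIAL" if replicate_on_last_tile_dim else "TILED"

print(classify_other([2, 2], False))  # TILED
print(classify_other([2, 2], True))   # PARTIAL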
@jonb377 @alanwaketan is this PR ready to merge?

Yes, I'll merge after TPU CI finishes.
This change adds a few utility functions to support distributed checkpointing. The following changes are included:
- Add sharding_type to XLAShardedTensor to get the ShardingType
- Add wrap_if_sharded to convert torch.Tensor into XLAShardedTensor if the underlying data is sharded.
- Remove the devices parameter from _get_local_shard_indices and instead always return the shard indices in the order of the shards
- Clamp the start bound of the index slices to the tensor's size.
- Add unpadded_data property of XLAShard
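A rough usage sketch of the new utilities, based only on the description above; the module path, names, and exact behavior here are assumptions and may differ from the actual torch_xla API:

# Assumed module path and behavior; treat this as a sketch, not the real API.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs

t = torch.randn(8, 8).to(xm.xla_device())
# ... t is assumed to have been sharded beforehand, e.g. via xs.mark_sharding(t, mesh, spec) ...

st = xs.wrap_if_sharded(t)   # XLAShardedTensor if t's data is sharded, else t unchanged
if isinstance(st, xs.XLAShardedTensor):
    print(st.sharding_type)  # the ShardingType of the underlying sharding
    for shard in st.local_shards:
        # unpadded_data strips the padding added to make shards equal-sized
        print(shard.unpadded_data.shape)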
sharding_typetoXLAShardedTensorto get the ShardingTypewrap_if_shardedto converttorch.TensorintoXLAShardedTensorif the underlying data is sharded.devicesparameter from_get_local_shard_indicesand instead always return the shard indices in the order of the shardsstartbound of the index slices to the tensor's size.unpadded_dataproperty ofXLAShard