Vectorize local shard retrieval #5826
Conversation
```cpp
XLATensorPtr xtensor = bridge::GetXlaTensor(input);
m.def(
    "_get_local_shard_replica_and_indices",
    [](const std::vector<at::Tensor>& input)
```
nit: rename `input` to `inputs` or `input_tensors`.
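A minimal sketch of the suggested rename, assuming the pybind11 binding from the hunk above. The module name and the lambda body here are placeholders, not the PR's actual implementation:

```cpp
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/extension.h>

#include <vector>

// Illustrative module only: shows the parameter renamed to `inputs`,
// since it now holds a list of tensors rather than a single one.
PYBIND11_MODULE(example_bindings, m) {
  m.def("_get_local_shard_replica_and_indices",
        [](const std::vector<at::Tensor>& inputs) {
          // Placeholder body: the real binding returns replica and
          // index metadata for each shard in `inputs`.
          return inputs.size();
        });
}
```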
LGTM
To capitalize on the improvements in #5824 and #5825, moving tensor shards to CPU should be batched. This change does the following:

- `_get_local_shards` and `_get_local_shard_replica_and_indices` operate on lists of tensors instead of individual tensors.
- `_sharded_cpu_state_dict` is updated to use the batched method across all sharded tensors in the state_dict.

With all three changes applied, the time spent transferring the state_dict to CPU for a 2B parameter model drops from over 10s to 3.4s, which unblocks training much more quickly.
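A hedged C++ sketch of the list-in/list-out shape this change describes. The function name `GetLocalShardsBatched` and its body are illustrative stand-ins: in torch_xla the batching happens inside the runtime, not in `at::Tensor::cpu()` as shown here.

```cpp
#include <torch/torch.h>

#include <vector>

// Illustrative sketch (not the PR's code): the binding now accepts all
// sharded tensors at once, so the runtime can batch every
// device-to-host copy instead of issuing one transfer per tensor.
std::vector<at::Tensor> GetLocalShardsBatched(
    const std::vector<at::Tensor>& inputs) {
  std::vector<at::Tensor> cpu_shards;
  cpu_shards.reserve(inputs.size());
  for (const at::Tensor& input : inputs) {
    // Stand-in for the batched transfer; in torch_xla the copies for
    // the whole list are issued together by the runtime.
    cpu_shards.push_back(input.cpu());
  }
  return cpu_shards;
}
```

With this shape, `_sharded_cpu_state_dict` can collect every sharded tensor in the state_dict and make a single batched call, rather than looping and transferring one tensor at a time.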
_get_local_shardsand_get_local_shard_replica_and_indicesoperate on lists of tensors instead of individual tensors._sharded_cpu_state_dictto use the batched method across all sharded tensors in the state_dict.With all three changes applied, the amount of time spent transferring the state_dict to CPU for a 2B parameter model decreases from >10s to 3.4s, which unblocks training much more quickly.