forked from pytorch/FBGEMM
Add variable batch size support to TBE training (pytorch#1752)
Summary:
Pull Request resolved: pytorch#1752

This diff adds support for variable batch size (or variable length) in split TBE training on GPU (the extension is called "VBE").

VBE is enabled for the following use case:
- split (`SplitTableBatchedEmbeddingBagsCodegen`), and
- pooled (`pooling_mode != PoolingMode.NONE`), and
- weighted/unweighted, and
- rowwise Adagrad optimizer (`optimizer == OptimType.EXACT_ROWWISE_ADAGRAD`)

Important note: This feature is enabled only for this specific use case in order to keep the binary size of the FBGEMM library within limits.

**Usage:**

```
# Initialize TBE the same way as before
emb_op = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[...],
    ...  # other params
)

# Batch sizes (one for each FEATURE and each RANK).
# Example: num_features = 2, num_ranks = 4
batch_size_per_feature_per_rank = [
    [1, 2, 8, 3],   # batch sizes for [Rank 0, Rank 1, Rank 2, Rank 3] in Feature 0
    [6, 10, 3, 5],  # batch sizes for [Rank 0, Rank 1, Rank 2, Rank 3] in Feature 1
]

# Pass a list of batch_size_per_feature_per_rank to forward.
# !! Make sure to pass batch_size_per_feature_per_rank as a keyword arg because there can be other keyword args in forward. !!
output = emb_op(indices, offsets, batch_size_per_feature_per_rank=batch_size_per_feature_per_rank)
```

**Output format**

{F967393126}

**Limitation:**

`T` and `max_B` have to fit in 32 bits.
- We use the lower `info_B_num_bits` bits to store `b` (bag ID; `b` < `max_B`). Supported `max_B` = `2^info_B_num_bits`
- We use the upper `32 - info_B_num_bits` bits to store `t` (table ID; `t` < `T`). Supported `T` = `2^(32 - info_B_num_bits)`

Note that we adjust `info_B_num_bits` automatically at runtime based on `max_B` and `T`. If they cannot fit into 32 bits, it will abort.

Differential Revision: D42663369

fbshipit-source-id: 7e9acf65e33f57b2ec876a0565ab80cd9e0fd3f8
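The 32-bit limitation described in the commit message can be illustrated with a small, pure-Python sketch. This is not the FBGEMM implementation: the helper names `choose_info_b_num_bits`, `pack_info`, and `unpack_info` are hypothetical, and picking the smallest bit width that fits `max_B` is an assumption about how the runtime adjustment might work. Only the layout itself (lower bits for `b`, upper bits for `t`, abort when both cannot fit) comes from the message above.

```python
INFO_NUM_BITS = 32  # total bits shared by table ID t and bag ID b (from the commit message)


def choose_info_b_num_bits(max_B: int, T: int) -> int:
    """Hypothetical helper: pick a bit width for b, then check that T fits in the rest."""
    if max_B < 1 or T < 1:
        raise ValueError("max_B and T must be positive")
    # b ranges over [0, max_B), so max_B - 1 must be representable.
    # Assumption: use the smallest width that works (at least 1 bit).
    info_b_num_bits = max(1, (max_B - 1).bit_length())
    if info_b_num_bits > INFO_NUM_BITS:
        raise ValueError("max_B does not fit in 32 bits")
    # The remaining upper bits hold t, so T may be at most 2^(32 - info_B_num_bits).
    if T > (1 << (INFO_NUM_BITS - info_b_num_bits)):
        raise ValueError("T and max_B do not fit in 32 bits together")
    return info_b_num_bits


def pack_info(t: int, b: int, info_b_num_bits: int) -> int:
    """Store t in the upper bits and b in the lower info_b_num_bits bits."""
    return (t << info_b_num_bits) | b


def unpack_info(info: int, info_b_num_bits: int) -> tuple[int, int]:
    """Recover (t, b) from a packed value."""
    b_mask = (1 << info_b_num_bits) - 1
    return info >> info_b_num_bits, info & b_mask


# Illustrative values only: T = 2 tables and max_B = 24 bags.
bits = choose_info_b_num_bits(max_B=24, T=2)
packed = pack_info(t=1, b=23, info_b_num_bits=bits)
print(bits, packed, unpack_info(packed, bits))
```

With these illustrative values, 5 bits cover every `b` below 24, leaving 27 bits for `t`. A call such as `choose_info_b_num_bits(max_B=2**31, T=4)` raises in this sketch, mirroring the abort described in the commit message.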
1 parent 01157c7, commit 7784112
Showing 6 changed files with 556 additions and 81 deletions.