
Add variable batch size support to TBE training #1752

Closed
wants to merge 1 commit

Commits on May 31, 2023

  1. Add variable batch size support to TBE training (pytorch#1752)

    Summary:
    Pull Request resolved: pytorch#1752
    
    This diff adds support for variable batch size (or variable length) in
    split TBE training on GPU (the extension is called "VBE").
    
    VBE is enabled for the following use case:
    - split (`SplitTableBatchedEmbeddingBagsCodegen`), and
    - pooled (`pooling_mode != PoolingMode.NONE`), and
    - weighted/unweighted, and
    - rowwise Adagrad optimizer (`optimizer ==
      OptimType.EXACT_ROWWISE_ADAGRAD`)
    
    Important note: This feature is enabled for a specific use case in
    order to keep the binary size of the FBGEMM library within limits.
    
    **Usage:**
    
    ```
    # Initialize TBE the same way as before
    emb_op = SplitTableBatchedEmbeddingBagsCodegen(
        embedding_specs=[...],
        ... # other params
    )
    
    # Batch sizes (one for each FEATURE and each RANK).
    # Example: num_features = 2, num_ranks = 4
    batch_size_per_feature_per_rank = [
        [1,  2, 8, 3],  # batch sizes for [Rank 0, Rank 1, Rank 2, Rank 3] in Feature 0
        [6, 10, 3, 5],  # batch sizes for [Rank 0, Rank 1, Rank 2, Rank 3] in Feature 1
    ]
    
    # Pass a list of batch_size_per_feature_per_rank to forward.
    # NOTE: batch_size_per_feature_per_rank must be passed as a keyword
    # argument, since forward takes other keyword arguments as well.
    output = emb_op(indices, offsets, batch_size_per_feature_per_rank=batch_size_per_feature_per_rank)
    ```
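As a quick sanity check on the example above, the total batch size of each feature is the sum of its per-rank batch sizes (this sketch only illustrates that arithmetic; it is not an FBGEMM API):

```python
# Per-feature, per-rank batch sizes from the usage example above.
batch_size_per_feature_per_rank = [
    [1, 2, 8, 3],   # Feature 0: ranks 0..3
    [6, 10, 3, 5],  # Feature 1: ranks 0..3
]

# Total batch size per feature is the sum across ranks.
total_B_per_feature = [sum(sizes) for sizes in batch_size_per_feature_per_rank]
print(total_B_per_feature)  # [14, 24]

# max_B, the largest per-feature total batch size, is what must fit in
# info_B_num_bits bits (see the Limitation section below).
max_B = max(total_B_per_feature)
print(max_B)  # 24
```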
    
    **Output format**
    
    (Figure: illustration of the VBE output layout; not reproduced here.)
    
    **Limitation:**
    
    `T` and `max_B` have to fit in 32 bits:
    - The lower `info_B_num_bits` bits store `b` (the bag ID; `b < max_B`), so the maximum supported `max_B` is `2^info_B_num_bits`.
    - The upper `32 - info_B_num_bits` bits store `t` (the table ID; `t < T`), so the maximum supported `T` is `2^(32 - info_B_num_bits)`.
    
    Note that `info_B_num_bits` is adjusted automatically at runtime based on `max_B` and `T`. If the two values cannot fit into 32 bits together, the operator aborts.
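The packing scheme described above can be sketched as follows. This is an illustration of the bit layout, not FBGEMM's actual kernel code; the function names are hypothetical:

```python
def pack_info(t: int, b: int, info_B_num_bits: int) -> int:
    """Pack table ID `t` and bag ID `b` into one 32-bit value:
    lower info_B_num_bits bits hold b, the remaining upper bits hold t.
    (Illustrative sketch, not FBGEMM's implementation.)"""
    max_B = 1 << info_B_num_bits
    max_T = 1 << (32 - info_B_num_bits)
    assert 0 <= b < max_B, "bag ID does not fit in info_B_num_bits bits"
    assert 0 <= t < max_T, "table ID does not fit in the remaining bits"
    return (t << info_B_num_bits) | b

def unpack_info(info: int, info_B_num_bits: int):
    """Recover (t, b) from a packed 32-bit info value."""
    b = info & ((1 << info_B_num_bits) - 1)
    t = info >> info_B_num_bits
    return t, b

# Example: with info_B_num_bits = 26, max_B can be up to 2**26
# and T up to 2**6 = 64.
info = pack_info(3, 1000, 26)
print(unpack_info(info, 26))  # (3, 1000)
```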
    
    Reviewed By: jianyuh
    
    Differential Revision: D42663369
    
    fbshipit-source-id: d613b0a9ced838e3ae8b421a1e5a30de8b158e69
    sryap authored and facebook-github-bot committed May 31, 2023
    Full commit SHA: dbff94e