## **Automated Planner Tutorial**
The planner attempts to identify the optimal sharding plan by evaluating a series of proposals, which are statically analyzed and fed into an integer partitioner. The planner can automatically select near-optimal sharding plans for a wide range of hardware setups, allowing users to scale performance seamlessly from a local development environment to large-scale production hardware.
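To make the propose, evaluate, partition and select loop more concrete, here is a toy, pure-Python sketch. `Shard`, `greedy_partition` and `best_proposal` are hypothetical names. The greedy "costliest shard to the least-loaded rank that still has room" rule is a stand-in chosen for illustration. Scoring a plan by its slowest rank is also our assumption. TorchRec's real estimators and partitioner model much more (comms, kernels, optimizer state).

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple


@dataclass
class Shard:
    hbm_gb: float   # estimated HBM this shard needs on whichever rank holds it
    perf_ms: float  # estimated per-iteration cost of this shard


def greedy_partition(
    shards: List[Shard], world_size: int, hbm_cap_gb: float
) -> Optional[List[float]]:
    """Place the costliest shards first, each on the least-loaded rank with room.

    Returns the estimated perf of each rank, or None if some shard fits nowhere.
    """
    perf = [0.0] * world_size
    hbm = [0.0] * world_size
    for shard in sorted(shards, key=lambda s: s.perf_ms, reverse=True):
        fits = [r for r in range(world_size) if hbm[r] + shard.hbm_gb <= hbm_cap_gb]
        if not fits:
            return None  # this proposal is infeasible under the memory cap
        rank = min(fits, key=lambda r: perf[r])
        perf[rank] += shard.perf_ms
        hbm[rank] += shard.hbm_gb
    return perf


def best_proposal(
    proposals: Dict[str, List[Shard]], world_size: int, hbm_cap_gb: float
) -> Optional[Tuple[str, List[float]]]:
    """Partition every proposal and keep the one whose slowest rank is fastest."""
    best: Optional[Tuple[str, List[float]]] = None
    for name, shards in proposals.items():
        perf = greedy_partition(shards, world_size, hbm_cap_gb)
        if perf is None:
            continue  # skip proposals that do not fit in memory
        if best is None or max(perf) < max(best[1]):
            best = (name, perf)
    return best


# Hypothetical numbers for one table on 2 ranks (not produced by the planner).
proposals = {
    "TW": [Shard(hbm_gb=2.4, perf_ms=0.004)],
    "RW": [Shard(hbm_gb=1.2, perf_ms=0.002), Shard(hbm_gb=1.2, perf_ms=0.002)],
}
print(best_proposal(proposals, world_size=2, hbm_cap_gb=32.0))  # ('RW', [0.002, 0.002])
```

Shrinking `hbm_cap_gb` below a proposal's largest shard makes `greedy_partition` return `None` for it. This mirrors how a proposal can be evaluated without yielding a valid plan.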

TorchRec ships with a built-in planner, which can be accessed directly via the `EmbeddingShardingPlanner` API or indirectly via the `DistributedModelParallel` API.

We recommend using the default planner as is. Most performance improvements come from adjusting its default arguments through the public-facing API.

In this tutorial we will explore how to:

- set up a `Topology`, run the planner to generate a `ShardingPlan`, and interpret basic diagnostic information
- use `ParameterConstraints` to select the sharding type and provide pooling factors
- trigger `UVM Caching` when the model exceeds GPU memory (this happens automatically!)

### Initial Setup (if needed)


Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU-enabled operations that TorchRec relies on.

```python
# install torchrec
!pip3 install torchrec-nightly
```

The following step is needed for the Colab runtime to detect the added shared libraries. The runtime searches for shared libraries in `/usr/lib`, so we copy over the libraries that were installed in `/usr/local/lib/`. **This step is necessary, but only in the Colab runtime.**

```python
!sudo cp /usr/local/lib/lib* /usr/lib/
```

**Restart your runtime at this point for the newly installed packages to be seen.** Run the step below immediately after restarting so that Python knows where to look for packages. **Always run this step after restarting the runtime.**

```python
import sys
sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages', './.local/lib/python3.7/site-packages']
```

### Constructing Embedding Tables

To showcase the versatility of the planner, we start by creating two embedding bag tables: a small one and a huge one.

```python
import torchrec

small_table = torchrec.EmbeddingBagConfig(
  name="small_table",
  embedding_dim=64,
  num_embeddings=10000000,
  feature_names=["small_table_feature"],
  pooling=torchrec.PoolingType.SUM,
)

huge_table = torchrec.EmbeddingBagConfig(
  name="huge_table",
  embedding_dim=1024,
  num_embeddings=50000000,
  feature_names=["large_table_feature"],
  pooling=torchrec.PoolingType.SUM,
)
```

### Running the Standalone Planner
TorchRec separates model sharding into two stages:
- Planning stage: determine "how" to shard the model for the given sharder(s) and run-time environment (`Topology`). The output is known as the sharding plan (`ShardingPlan`).
- Sharding stage: use the given sharder(s) to shard the model according to the sharding plan. This stage must be executed in the run-time environment.

This separation of responsibilities allows the planning stage to be executed independently of the run-time environment. We leverage that in this tutorial to explore various hypothetical scenarios.

OK, let's start with the small table and look at the planner output for a single-GPU run-time.

```python
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.embeddingbag import (
    EmbeddingBagCollectionSharder,
)

# Toy model. We allocate it on the "meta" device so that no parameter memory is actually allocated
small_toy_model = torchrec.EmbeddingBagCollection(
    device="meta", 
    tables=[small_table]
)

# Create a run-time environment with 1 GPU
topology_1gpu = Topology(
    world_size=1, 
    compute_device="cuda"
)

planner = EmbeddingShardingPlanner(
    topology=topology_1gpu,
    debug=True,
)

sharding_plan = planner.plan(small_toy_model, sharders=[EmbeddingBagCollectionSharder()])
```

```text
I0617 093200.096 stats.py:190] ###################################################################################################
I0617 093200.096 stats.py:192] #                                   --- Planner Statistics ---                                    #
I0617 093200.096 stats.py:199] #           --- Evalulated 16 proposal(s), found 16 possible plan(s), ran for 0.02s ---           #
I0617 093200.097 stats.py:202] # ----------------------------------------------------------------------------------------------- #
I0617 093200.098 stats.py:205] #      Rank     HBM (GB)     DDR (GB)     Perf (ms)     Input (MB)     Output (MB)     Shards     #
I0617 093200.098 stats.py:205] #    ------   ----------   ----------   -----------   ------------   -------------   --------     #
I0617 093200.099 stats.py:205] #         0     2.4 (7%)     0.0 (0%)         0.001            0.0            0.12      TW: 1     #
I0617 093200.099 stats.py:207] #                                                                                                 #
I0617 093200.100 stats.py:209] # Input: MB/iteration, Output: MB/iteration, Shards: number of tables                             #
I0617 093200.101 stats.py:211] # HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients             #
I0617 093200.101 stats.py:212] #                                                                                                 #
I0617 093200.102 stats.py:218] # Compute Kernels:                                                                                #
I0617 093200.102 stats.py:220] #   batched_fused: 1                                                                              #
I0617 093200.103 stats.py:223] #                                                                                                 #
I0617 093200.103 stats.py:224] # Parameter Info:                                                                                 #
I0617 093200.104 stats.py:227] #             FQN     Sharding     Compute Kernel     Perf (ms)     Ranks                         #
I0617 093200.104 stats.py:227] #           -----   ----------   ----------------   -----------   -------                         #
I0617 093200.105 stats.py:227] #    .small_table           TW      batched_fused         0.001         0                         #
I0617 093200.105 stats.py:229] ###################################################################################################
```


The `Planner Statistics` output gives a high-level summary of the sharding plan. Some things to note:
- A total of 16 different proposals were evaluated, and all of them were valid plans. The planner returns the best sharding plan, defined as the plan with the lowest `Perf` value.
- In this situation, the table is `Table-wise` (TW) sharded, uses the `batched_fused` compute kernel, and is estimated to consume 2.4 GB of HBM.
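The 2.4 GB estimate can be sanity-checked by hand. Multiply 10,000,000 rows by 64 columns by 4 bytes per fp32 weight. This minimal sketch counts weights only and assumes fp32 storage. It also assumes a 32 GiB HBM capacity per GPU, a figure we infer because it is consistent with the 7% shown in the log. `table_size_gib` is a hypothetical helper.

```python
GIB = 1024 ** 3


def table_size_gib(num_embeddings: int, embedding_dim: int, bytes_per_element: int = 4) -> float:
    """Weight size of a dense embedding table; defaults to fp32 (4-byte) elements."""
    return num_embeddings * embedding_dim * bytes_per_element / GIB


small_gib = table_size_gib(10_000_000, 64)
print(f"{small_gib:.1f} GiB")   # 2.4 GiB
# Fraction of an assumed 32 GiB of HBM per GPU
print(f"{small_gib / 32:.0%}")  # 7%
```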

Now let's look at what happens when we increase `world_size` to 2:



```python
# Create a run-time environment with 2 GPUs
topology_2gpu = Topology(
    world_size=2, 
    compute_device="cuda"
)

planner = EmbeddingShardingPlanner(
    topology=topology_2gpu,
    debug=True,
)

sharding_plan = planner.plan(small_toy_model, sharders=[EmbeddingBagCollectionSharder()])
```

```text
I0617 093202.598 stats.py:190] ###################################################################################################
I0617 093202.598 stats.py:192] #                                   --- Planner Statistics ---                                    #
I0617 093202.599 stats.py:199] #           --- Evalulated 16 proposal(s), found 16 possible plan(s), ran for 0.02s ---           #
I0617 093202.600 stats.py:202] # ----------------------------------------------------------------------------------------------- #
I0617 093202.601 stats.py:205] #      Rank     HBM (GB)     DDR (GB)     Perf (ms)     Input (MB)     Output (MB)     Shards     #
I0617 093202.602 stats.py:205] #    ------   ----------   ----------   -----------   ------------   -------------   --------     #
I0617 093202.603 stats.py:205] #         0     1.2 (4%)     0.0 (0%)         0.003            0.0            0.12      RW: 1     #
I0617 093202.603 stats.py:205] #         1     1.2 (4%)     0.0 (0%)         0.003            0.0            0.12      RW: 1     #
I0617 093202.604 stats.py:207] #                                                                                                 #
I0617 093202.605 stats.py:209] # Input: MB/iteration, Output: MB/iteration, Shards: number of tables                             #
I0617 093202.605 stats.py:211] # HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients             #
I0617 093202.606 stats.py:212] #                                                                                                 #
I0617 093202.606 stats.py:218] # Compute Kernels:                                                                                #
I0617 093202.607 stats.py:220] #   batched_fused: 1                                                                              #
I0617 093202.608 stats.py:223] #                                                                                                 #
I0617 093202.608 stats.py:224] # Parameter Info:                                                                                 #
I0617 093202.609 stats.py:227] #             FQN     Sharding     Compute Kernel     Perf (ms)     Ranks                         #
I0617 093202.610 stats.py:227] #           -----   ----------   ----------------   -----------   -------                         #
I0617 093202.610 stats.py:227] #    .small_table           RW      batched_fused         0.005       0-1                         #
I0617 093202.611 stats.py:229] ###################################################################################################
```


With two GPUs:
- The same parameter is now sharded `Row-wise` (RW).
- Each rank's estimated peak memory usage is 1.2 GB, reflecting that the parameter is split evenly between the two ranks.
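Row-wise sharding gives each rank a contiguous block of rows. This simplified sketch splits 10,000,000 rows over 2 ranks and checks that each half holds about 1.2 GB of fp32 weights. `row_wise_ranges` is a hypothetical helper. Spreading the remainder over the first ranks is our choice, and TorchRec's own row-wise layout may assign rows differently when they do not divide evenly.

```python
from typing import List, Tuple

GIB = 1024 ** 3


def row_wise_ranges(num_embeddings: int, world_size: int) -> List[Tuple[int, int]]:
    """Split rows into contiguous, near-equal [start, end) blocks, one per rank."""
    base, extra = divmod(num_embeddings, world_size)
    ranges: List[Tuple[int, int]] = []
    start = 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)  # first `extra` ranks take one extra row
        ranges.append((start, start + size))
        start += size
    return ranges


for rank, (start, end) in enumerate(row_wise_ranges(10_000_000, world_size=2)):
    gib = (end - start) * 64 * 4 / GIB  # embedding_dim=64, fp32 weights
    print(f"rank {rank}: rows [{start}, {end}) -> {gib:.2f} GiB")
# rank 0: rows [0, 5000000) -> 1.19 GiB
# rank 1: rows [5000000, 10000000) -> 1.19 GiB
```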

A natural question at this point is: how do we know this is the best plan? For example, how does it compare to simply using table-wise sharding? The planner API lets us check this with `ParameterConstraints`, which restrict the sharding types the planner may consider for a given table.


```python
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import ShardingType

# Restrict "small_table" to table-wise sharding only
constraints = {
    "small_table": ParameterConstraints(
        sharding_types=[ShardingType.TABLE_WISE.value]
    ),
}

planner = EmbeddingShardingPlanner(
    topology=topology_2gpu,
    constraints=constraints,
    debug=True,
)
plan = planner.plan(small_toy_model, sharders=[EmbeddingBagCollectionSharder()])
```

```text
I0617 094107.148 stats.py:190] ###################################################################################################
I0617 094107.148 stats.py:192] #                                   --- Planner Statistics ---                                    #
I0617 094107.149 stats.py:199] #            --- Evalulated 3 proposal(s), found 3 possible plan(s), ran for 0.00s ---            #
I0617 094107.149 stats.py:202] # ----------------------------------------------------------------------------------------------- #
I0617 094107.150 stats.py:205] #      Rank     HBM (GB)     DDR (GB)     Perf (ms)     Input (MB)     Output (MB)     Shards     #
I0617 094107.151 stats.py:205] #    ------   ----------   ----------   -----------   ------------   -------------   --------     #
I0617 094107.151 stats.py:205] #         0     2.4 (7%)     0.0 (0%)         0.003           0.01            0.25      TW: 1     #
I0617 094107.152 stats.py:205] #         1     0.0 (0%)     0.0 (0%)           0.0            0.0             0.0      TW: 0     #
I0617 094107.152 stats.py:207] #                                                                                                 #
I0617 094107.153 stats.py:209] # Input: MB/iteration, Output: MB/iteration, Shards: number of tables                             #
I0617 094107.154 stats.py:211] # HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients             #
I0617 094107.154 stats.py:212] #                                                                                                 #
I0617 094107.155 stats.py:218] # Compute Kernels:                                                                                #
I0617 094107.155 stats.py:220] #   batched_fused: 1                                                                              #
I0617 094107.156 stats.py:223] #                                                                                                 #
I0617 094107.156 stats.py:224] # Parameter Info:                                                                                 #
I0617 094107.157 stats.py:227] #             FQN     Sharding     Compute Kernel     Perf (ms)     Ranks                         #
I0617 094107.157 stats.py:227] #           -----   ----------   ----------------   -----------   -------                         #
I0617 094107.158 stats.py:227] #    .small_table           TW      batched_fused         0.003         0                         #
I0617 094107.158 stats.py:229] ###################################################################################################
```


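Requiring the planner to use table-wise sharding gives the same estimated `Perf` on the busiest rank (0.003 ms). However, the whole parameter (2.4 GB) now sits on rank 0 while rank 1 holds nothing. The constraint also narrows the search to only 3 proposals.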

Another important use of `ParameterConstraints` is to give the planner the pooling factor(s) for each table, that is, the average number of ids looked up per sample for a feature. In models with many tables, this becomes the dominant factor in estimating model `Perf`. It is also critical to the planner's ability to estimate peak HBM usage accurately.
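If you can sample per-sample id counts for a feature from your training data, one simple way to estimate its pooling factor is to average them. This is a minimal sketch under that assumption. `pooling_factor` is a hypothetical helper, and the sample lengths are made-up numbers chosen to average to the 200.0 used below.

```python
from typing import Sequence


def pooling_factor(lengths: Sequence[int]) -> float:
    """Average number of ids per sample for one feature, i.e. its pooling factor."""
    if len(lengths) == 0:
        raise ValueError("need at least one sample")
    return sum(lengths) / len(lengths)


# Hypothetical per-sample id counts for "small_table_feature"; in practice these
# could come from KeyedJaggedTensor lengths collected over some training batches.
sample_lengths = [180, 220, 195, 205]
pf = pooling_factor(sample_lengths)
print(pf)  # 200.0
# The estimate would then be passed as ParameterConstraints(pooling_factors=[pf]).
```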


```python
constraints = {
    "small_table": ParameterConstraints(
        pooling_factors=[200.0],
    ),
}

planner = EmbeddingShardingPlanner(
    topology=topology_2gpu,
    constraints=constraints,
    debug=True,
)
plan = planner.plan(small_toy_model, sharders=[EmbeddingBagCollectionSharder()])
```

```text
I0617 095407.058 stats.py:190] ###################################################################################################
I0617 095407.059 stats.py:192] #                                   --- Planner Statistics ---                                    #
I0617 095407.060 stats.py:199] #           --- Evalulated 16 proposal(s), found 16 possible plan(s), ran for 0.02s ---           #
I0617 095407.060 stats.py:202] # ----------------------------------------------------------------------------------------------- #
I0617 095407.061 stats.py:205] #      Rank     HBM (GB)     DDR (GB)     Perf (ms)     Input (MB)     Output (MB)     Shards     #
I0617 095407.061 stats.py:205] #    ------   ----------   ----------   -----------   ------------   -------------   --------     #
I0617 095407.062 stats.py:205] #         0     1.2 (4%)     0.0 (0%)         0.087           0.78            0.12      RW: 1     #
I0617 095407.063 stats.py:205] #         1     1.2 (4%)     0.0 (0%)         0.087           0.78            0.12      RW: 1     #
I0617 095407.063 stats.py:207] #                                                                                                 #
I0617 095407.064 stats.py:209] # Input: MB/iteration, Output: MB/iteration, Shards: number of tables                             #
I0617 095407.065 stats.py:211] # HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients             #
I0617 095407.066 stats.py:212] #                                                                                                 #
I0617 095407.067 stats.py:218] # Compute Kernels:                                                                                #
I0617 095407.067 stats.py:220] #   batched_fused: 1                                                                              #
I0617 095407.068 stats.py:223] #                                                                                                 #
I0617 095407.069 stats.py:224] # Parameter Info:                                                                                 #
I0617 095407.069 stats.py:227] #             FQN     Sharding     Compute Kernel     Perf (ms)     Ranks                         #
I0617 095407.070 stats.py:227] #           -----   ----------   ----------------   -----------   -------                         #
I0617 095407.071 stats.py:227] #    .small_table           RW      batched_fused         0.173       0-1                         #
I0617 095407.072 stats.py:229] ###################################################################################################
```


Providing a pooling factor of 200.0 raises the estimated per-rank `Perf` from 0.003 ms to 0.087 ms. It also raises the per-rank `Input` from 0.0 MB to 0.78 MB per iteration. The planner still chooses row-wise sharding.
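The `Input` and `Output` columns in these logs can be reproduced by hand under three assumptions: a per-rank batch size of 512, 8-byte (int64) ids, and fp32 pooled outputs. These assumptions are our reading of the logged numbers, not documented planner internals. `input_mb`, `output_mb` and `BATCH_SIZE` are hypothetical names. The sketch does show why `Input` grows linearly with the pooling factor while `Output` depends only on the embedding dimension.

```python
MIB = 1024 ** 2
BATCH_SIZE = 512  # assumed per-rank batch size behind the estimates


def input_mb(batch_size: int, pooling_factor: float, id_bytes: int = 8) -> float:
    """Ids read per iteration: batch_size x pooling_factor ids, 8 bytes (int64) each."""
    return batch_size * pooling_factor * id_bytes / MIB


def output_mb(batch_size: int, embedding_dim: int, elem_bytes: int = 4) -> float:
    """Pooled embeddings per iteration: one fp32 vector of embedding_dim per sample."""
    return batch_size * embedding_dim * elem_bytes / MIB


print(input_mb(BATCH_SIZE, 1.0))    # 0.00390625 (logged as 0.0 without a pooling factor)
print(input_mb(BATCH_SIZE, 200.0))  # 0.78125 (logged as 0.78 with pooling_factors=[200.0])
print(output_mb(BATCH_SIZE, 64))    # 0.125 (logged as 0.12)
```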

Finally, let's take a quick look at what happens when a table is too large to fit in HBM.

```python
from torchrec.distributed.types import ShardingType


# Toy model
huge_toy_model = torchrec.EmbeddingBagCollection(device="meta", tables=[huge_table])

planner = EmbeddingShardingPlanner(
    topology=topology_2gpu,
    debug=True,
)
sharding_plan = planner.plan(huge_toy_model, sharders=[EmbeddingBagCollectionSharder()])
```

```text
I0617 095651.613 stats.py:190] ###################################################################################################
I0617 095651.614 stats.py:192] #                                   --- Planner Statistics ---                                    #
I0617 095651.614 stats.py:199] #           --- Evalulated 16 proposal(s), found 8 possible plan(s), ran for 0.02s ---            #
I0617 095651.615 stats.py:202] # ----------------------------------------------------------------------------------------------- #
I0617 095651.617 stats.py:205] #      Rank     HBM (GB)     DDR (GB)     Perf (ms)     Input (MB)     Output (MB)     Shards     #
I0617 095651.617 stats.py:205] #    ------   ----------   ----------   -----------   ------------   -------------   --------     #
I0617 095651.619 stats.py:205] #         0   19.1 (60%)   95.4 (75%)         0.543           0.03             2.0      CW: 4     #
I0617 095651.619 stats.py:205] #         1   19.1 (60%)   95.4 (75%)         0.543           0.03             2.0      CW: 4     #
I0617 095651.620 stats.py:207] #                                                                                                 #
I0617 095651.621 stats.py:209] # Input: MB/iteration, Output: MB/iteration, Shards: number of tables                             #
I0617 095651.622 stats.py:211] # HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients             #
I0617 095651.622 stats.py:212] #                                                                                                 #
I0617 095651.623 stats.py:218] # Compute Kernels:                                                                                #
I0617 095651.624 stats.py:220] #   batched_fused_uvm_caching: 1                                                                  #
I0617 095651.625 stats.py:223] #                                                                                                 #
I0617 095651.626 stats.py:224] # Parameter Info:                                                                                 #
I0617 095651.627 stats.py:227] #            FQN     Sharding              Compute Kernel     Perf (ms)             Ranks         #
I0617 095651.628 stats.py:227] #          -----   ----------            ----------------   -----------           -------         #
I0617 095651.628 stats.py:227] #    .huge_table           CW   batched_fused_uvm_caching         1.086   0,0,0,0,1,1,1,1         #
I0617 095651.630 stats.py:229] ###################################################################################################
```


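The planner recognizes that the table cannot fit in HBM. It selects the UVM caching compute kernel (`batched_fused_uvm_caching`) to accommodate the large parameter, with an estimated 95.4 GB of host memory (DDR) and 19.1 GB of HBM per rank. The planner has also determined that `Column-wise` (CW) sharding is optimal in this situation, splitting the table into 8 shards with 4 per rank. Only 8 of the 16 evaluated proposals yielded valid plans.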
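The DDR figure can be checked against the column-wise layout in the log: 8 shards over 1024 columns, 4 shards per rank, fp32 weights. The fp32 size and even column split are our assumptions. The final ratio simply divides the logged HBM usage by the computed per-rank table size. A result of roughly 0.2 suggests that HBM holds a cache of about a fifth of each rank's rows, but that is an inference from these numbers, not a documented planner setting.

```python
GIB = 1024 ** 3

num_embeddings, embedding_dim = 50_000_000, 1024  # huge_table
num_shards, world_size = 8, 2                      # from the log: CW, ranks 0,0,0,0,1,1,1,1

cols_per_shard = embedding_dim // num_shards                # columns per CW shard
shard_gib = num_embeddings * cols_per_shard * 4 / GIB       # fp32 weights per shard
per_rank_gib = shard_gib * (num_shards // world_size)       # 4 shards per rank

print(cols_per_shard)          # 128
print(round(shard_gib, 2))     # 23.84
print(round(per_rank_gib, 1))  # 95.4, matching the DDR column

# Ratio of the logged per-rank HBM usage to the full per-rank table size
hbm_per_rank_gib = 19.1
print(round(hbm_per_rank_gib / per_rank_gib, 2))  # 0.2
```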

I hope this tutorial has given you a brief introduction to the planner. More detailed functionality is described in the planner API documentation.


