## **Installation**
Requirements:
- python >= 3.9
- a machine with 2 GPUs

We highly recommend using CUDA with TorchRec. If you use CUDA, you need:
- cuda >= 12.0
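The requirements above can be checked with a short script before installing. This is a sketch with a few assumptions:

- `check_environment` is a hypothetical helper, not part of TorchRec.
- PyTorch is treated as optional, so the check can run before the install step.
- `torch.version.cuda` reports the CUDA version PyTorch was built with, which is not necessarily the driver version installed on the machine.

```python
import sys


def check_environment(min_python=(3, 9), min_gpus=2):
    """Report whether this machine meets the tutorial's stated requirements."""
    report = {"python_ok": sys.version_info[:2] >= min_python}
    try:
        import torch  # optional: PyTorch may not be installed yet
    except ImportError:
        report.update(torch_installed=False, cuda_available=False,
                      gpu_count=0, cuda_version=None)
    else:
        cuda_available = torch.cuda.is_available()
        report.update(
            torch_installed=True,
            cuda_available=cuda_available,
            gpu_count=torch.cuda.device_count() if cuda_available else 0,
            # CUDA version PyTorch was built against; None for CPU-only builds
            cuda_version=torch.version.cuda,
        )
    report["enough_gpus"] = report["gpu_count"] >= min_gpus
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```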


In [1]:
!pip3 install --pre torch --index-url https://download.pytorch.org/whl/cu121 -U
!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/cu121
!pip3 install torchmetrics
!pip3 install torchrec --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Looking in indexes: https://download.pytorch.org/whl/cu121
Looking in indexes: https://download.pytorch.org/whl/cu121

In [2]:
!pip3 install multiprocess

## **Overview**
This tutorial mainly covers the sharding schemes for embedding tables that are available through the `EmbeddingShardingPlanner` and `DistributedModelParallel` APIs. It also explores the benefits of the different schemes by configuring them explicitly.

### Distributed Setup
Because of the notebook environment, we cannot run an [`SPMD`](https://en.wikipedia.org/wiki/SPMD) program here. Instead, we use multiprocessing inside the notebook to mimic that setup. When using TorchRec, users are responsible for setting up their own [`SPMD`](https://en.wikipedia.org/wiki/SPMD) launcher.
We set up our environment so that the torch.distributed communication backend can work.

In [3]:
import os
import copy
import torch
import torchrec
import multiprocess
from torchrec.distributed.types import ShardingEnv

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "10000"

The code below sets up a single process with rank `0` and a WORLD_SIZE (number of processes) of `1`.

In a distributed setup, these steps are repeated on each process to set up the distributed environment.

In [4]:
import torch.distributed as dist
device = 'cuda' if torch.cuda.is_available() else 'cpu'
rank = 0
world_size = 1
backend = 'nccl' if 'cuda' in device else 'gloo'
os.environ["RANK"] = f"{rank}"
os.environ["WORLD_SIZE"] = f"{world_size}"
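The cell above configures only rank `0`. In a real SPMD launch, each process would compute its own settings. The sketch below shows what that might look like, using the hypothetical helper `rank_environment`. It derives the per-rank environment variables, backend, and device, and assumes this notebook's conventions:

- the master address and port are `localhost:10000`;
- NCCL is used on GPU and Gloo on CPU;
- each rank gets one GPU on a single host.

```python
def rank_environment(rank, world_size, master_addr="localhost",
                     master_port=10000, use_cuda=True):
    """Hypothetical helper: env vars, backend and device for one SPMD process."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    env = {
        "MASTER_ADDR": master_addr,
        "MASTER_PORT": str(master_port),
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
    }
    backend = "nccl" if use_cuda else "gloo"
    # On a single host, each NCCL rank conventionally drives GPU cuda:<rank>.
    device = f"cuda:{rank}" if use_cuda else "cpu"
    return env, backend, device


for r in range(2):
    print(rank_environment(r, world_size=2))
```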

### Constructing our embedding model
We use TorchRec's [`EmbeddingBagCollection`](https://github.com/facebookresearch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L59) to construct our embedding bag model with embedding tables.

Here, we create an EmbeddingBagCollection (EBC) with four embedding bags. There are two types of tables, large and small, which differ in their number of rows: 4096 vs. 1024. Every table uses a 64-dimensional embedding.
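To put these sizes in perspective, the sketch below computes the weight memory of each table. It assumes float32 weights (4 bytes per value) and ignores optimizer state. `table_bytes` is a hypothetical helper.

```python
BYTES_PER_FP32 = 4  # assumption: weights stored as float32


def table_bytes(num_embeddings, embedding_dim, bytes_per_value=BYTES_PER_FP32):
    """Bytes needed for one embedding table's weights (no optimizer state)."""
    return num_embeddings * embedding_dim * bytes_per_value


tables = {
    "large_table_0": (4096, 64),
    "large_table_1": (4096, 64),
    "small_table_0": (1024, 64),
    "small_table_1": (1024, 64),
}
sizes = {name: table_bytes(rows, dim) for name, (rows, dim) in tables.items()}
total = sum(sizes.values())
print(sizes)  # large tables: 1048576 bytes (1 MiB); small tables: 262144 bytes (256 KiB)
print(total)  # 2621440 bytes (2.5 MiB)
```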

We configure the `ParameterConstraints` data structure for the tables. It provides hints that help the model parallel API decide the sharding and placement strategy for each table.
TorchRec supports the following sharding types:
* `table-wise`: place the entire table on one device;
* `row-wise`: shard the table evenly by row dimension and place one shard on each device of the communication world;
* `column-wise`: shard the table evenly by embedding dimension and place one shard on each device of the communication world;
* `table-row-wise`: a special sharding optimized for intra-host communication over fast intra-machine device interconnects such as NVLink;
* `data_parallel`: replicate the tables on every device.
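To make the shard shapes concrete, the sketch below computes the `(shard offsets, shard sizes, rank)` triples that each scheme would produce for a single `rows x dim` table. It uses the same `[row, column]` format the planner prints later.

It is an illustration under simplifying assumptions:

- shards are split evenly;
- the caller chooses the owner rank for table-wise;
- `table-row-wise` is left out, because it needs multiple hosts.

TorchRec's planner decides actual shard sizes with its own logic. `shard_layout` is a hypothetical name.

```python
from math import ceil


def shard_layout(rows, dim, world_size, sharding_type, owner_rank=0):
    """Return (offsets, sizes, rank) triples, mirroring the plan printout columns."""
    if sharding_type == "table_wise":
        # The whole table lives on a single rank.
        return [([0, 0], [rows, dim], owner_rank)]
    if sharding_type == "row_wise":
        block = ceil(rows / world_size)  # the last shard may be smaller
        return [([start, 0], [min(block, rows - start), dim], r)
                for r, start in enumerate(range(0, rows, block))]
    if sharding_type == "column_wise":
        block = ceil(dim / world_size)
        return [([0, start], [rows, min(block, dim - start)], r)
                for r, start in enumerate(range(0, dim, block))]
    if sharding_type == "data_parallel":
        # Every rank holds a full replica of the table.
        return [([0, 0], [rows, dim], r) for r in range(world_size)]
    raise ValueError(f"unsupported sharding type: {sharding_type}")


for kind in ("table_wise", "row_wise", "column_wise", "data_parallel"):
    print(kind, shard_layout(4096, 64, world_size=2, sharding_type=kind))
```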

Note that in the cell below we allocate the EBC directly on `device`. TorchRec modules can also be allocated on the `"meta"` device, which tells the EBC not to allocate memory yet.

In [5]:
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType, ShardingPlan
from typing import Dict

large_table_cnt = 2
small_table_cnt = 2
large_tables=[
  torchrec.EmbeddingBagConfig(
    name="large_table_" + str(i),
    embedding_dim=64,
    num_embeddings=4096,
    feature_names=["large_table_feature_" + str(i)],
    pooling=torchrec.PoolingType.SUM,
  ) for i in range(large_table_cnt)
]
small_tables=[
  torchrec.EmbeddingBagConfig(
    name="small_table_" + str(i),
    embedding_dim=64,
    num_embeddings=1024,
    feature_names=["small_table_feature_" + str(i)],
    pooling=torchrec.PoolingType.SUM,
  ) for i in range(small_table_cnt)
]

def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
  large_table_constraints = {
    "large_table_" + str(i): ParameterConstraints(
      sharding_types=[sharding_type.value],
    ) for i in range(large_table_cnt)
  }
  small_table_constraints = {
    "small_table_" + str(i): ParameterConstraints(
      sharding_types=[sharding_type.value],
    ) for i in range(small_table_cnt)
  }
  constraints = {**large_table_constraints, **small_table_constraints}
  return constraints
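`gen_constraints` applies one sharding type to every table. Because the constraints are keyed by table name, however, they can differ per table.

The plain-Python sketch below builds a hypothetical mixed mapping: row-wise for the large tables and table-wise for the small ones. The strings are the `ShardingType` values that appear in the printed plans. In the notebook, each list would be passed as `ParameterConstraints(sharding_types=...)`. `gen_mixed_sharding_types` is an illustrative name, not a TorchRec function.

```python
def gen_mixed_sharding_types(large_table_cnt=2, small_table_cnt=2,
                             large_type="row_wise", small_type="table_wise"):
    """Hypothetical helper: map each table name to its allowed sharding types.

    The strings match ShardingType values; in the notebook each list would be
    passed as ParameterConstraints(sharding_types=...).
    """
    mapping = {f"large_table_{i}": [large_type] for i in range(large_table_cnt)}
    mapping.update({f"small_table_{i}": [small_type] for i in range(small_table_cnt)})
    return mapping


print(gen_mixed_sharding_types())
# {'large_table_0': ['row_wise'], 'large_table_1': ['row_wise'],
#  'small_table_0': ['table_wise'], 'small_table_1': ['table_wise']}
```

Putting more than one type in a table's list should leave the choice among them to the planner.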

In [6]:
ebc = torchrec.EmbeddingBagCollection(
    device=torch.device(device),
    tables=large_tables + small_tables
)

In [7]:
print(ebc)

EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (large_table_0): EmbeddingBag(4096, 64, mode='sum')
    (large_table_1): EmbeddingBag(4096, 64, mode='sum')
    (small_table_0): EmbeddingBag(1024, 64, mode='sum')
    (small_table_1): EmbeddingBag(1024, 64, mode='sum')
  )
)




### DistributedModelParallel in multiprocessing
If you have access to **2 GPUs**, we can run multi-GPU, multi-process sharding. Processes started with the "spawn" method may not show their prints in the notebook, so output from some ranks can be missing. You can still check whether the assertions pass.

We define a single-process execution function that mimics one rank's work during [`SPMD`](https://en.wikipedia.org/wiki/SPMD) execution.

This code shards the model collectively with the other processes and allocates memory accordingly. It works in three steps:

1. It sets up the process group.
2. It uses the planner to place the embedding tables.
3. It generates the sharded model with `DistributedModelParallel`.


In [8]:
def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str
) -> torch.nn.Module:

    import os
    import torch
    import torch.distributed as dist
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
    from typing import cast

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    topology = Topology(world_size=world_size, compute_device=device.type)
    pg = init_distributed_single_host(rank, world_size, backend)
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)

    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
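    print(f"rank:{rank},sharding plan: {plan}")

    # When run as a multiprocessing target, this return value is discarded;
    # the sharded model is only usable inside this process.
    return sharded_model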


### Multiprocessing Execution


In [9]:
def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size = 2, # Change this world size according to the number of GPUs available on your end.
):
  ctx = multiprocess.get_context("spawn")
  processes = []

  for rank in range(world_size):
      p = ctx.Process(
          target=single_rank_execution,
          args=(
              rank,
              world_size,
              gen_constraints(sharding_type),
              ebc,
              backend,
          ),
      )
      print(f"start for rank: {rank}")
      p.start()
      processes.append(p)

  for p in processes:
      p.join()
      print(f"exit code: {p.exitcode}")
      assert 0 == p.exitcode

Now we can start the multiprocess sharding. The child processes are created with the spawn method, so their output may be missing or interleaved in the notebook. Even so, all assertions should pass and every process should exit with code 0.

### Table Wise Sharding
Now let's execute the code in two processes, one per GPU. The printed plan shows how the tables are sharded across GPUs. Each rank gets one large table and one small table, which shows that the planner tries to balance load across devices. Table-wise is the de facto go-to scheme when many small- to medium-sized tables need to be balanced across devices.

In [10]:
spmd_sharing_simulation(ShardingType.TABLE_WISE)

start for rank: 0
start for rank: 1
rank:0,sharding plan: module: 

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
large_table_0 | table_wise    | fused          | [0]  
large_table_1 | table_wise    | fused          | [1]  
small_table_0 | table_wise    | fused          | [0]  
small_table_1 | table_wise    | fused          | [1]  

    param     | shard offsets | shard sizes |   placement  
------------- | ------------- | ----------- | -------------
large_table_0 | [0, 0]        | [4096, 64]  | rank:0/cuda:0
large_table_1 | [0, 0]        | [4096, 64]  | rank:1/cuda:1
small_table_0 | [0, 0]        | [1024, 64]  | rank:0/cuda:0
small_table_1 | [0, 0]        | [1024, 64]  | rank:1/cuda:1
rank:1,sharding plan: module: 

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
large_table_0 | table_wise    | fused          | [0]  
large_table_1 | table_wise    | fused       



exit code: 0
exit code: 0
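A simple greedy heuristic reproduces the balanced placement in the plan above. It places tables largest-first, each on the currently least-loaded rank, using rows x dim as the size.

This is only an illustration. TorchRec's `EmbeddingShardingPlanner` chooses placements with its own storage and performance estimates. `greedy_table_wise` is a hypothetical helper.

```python
def greedy_table_wise(table_sizes, world_size):
    """Place each whole table on the least-loaded rank, largest tables first."""
    loads = [0] * world_size
    placement = {}
    # sorted(..., reverse=True) is stable, so equal-sized tables keep input order.
    for name, size in sorted(table_sizes.items(), key=lambda kv: kv[1], reverse=True):
        rank = min(range(world_size), key=loads.__getitem__)  # first least-loaded rank
        placement[name] = rank
        loads[rank] += size
    return placement, loads


sizes = {  # rows * dim for each table
    "large_table_0": 4096 * 64,
    "large_table_1": 4096 * 64,
    "small_table_0": 1024 * 64,
    "small_table_1": 1024 * 64,
}
placement, loads = greedy_table_wise(sizes, world_size=2)
print(placement)
# {'large_table_0': 0, 'large_table_1': 1, 'small_table_0': 0, 'small_table_1': 1}
print(loads)  # [327680, 327680]
```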


### Explore other sharding modes
We have seen what table-wise sharding looks like and how it balances table placement. Now we explore a sharding mode with a finer-grained focus on load balance: row-wise.

Row-wise sharding targets tables that are too large for a single device to hold because they have so many embedding rows. It makes it possible to place very large tables in your model. In the `shard sizes` column of the printed plan, you can see that each table is halved along the row dimension and distributed across the two GPUs.

If you are running on CPU, row-wise sharding is not allowed.
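Row-wise sharding also changes how lookups are served. Each ID must be sent to the rank that owns its row and translated into an offset within that shard.

The sketch below assumes contiguous, evenly sized row blocks, which matches the offsets `0` and `2048` in the row-wise plan printed by the next cell. `route_row_wise` is a hypothetical helper. TorchRec performs this redistribution with collective communication, not a Python loop.

```python
from collections import defaultdict
from math import ceil


def route_row_wise(ids, num_rows, world_size):
    """Group global row IDs by owning rank and convert them to shard-local IDs."""
    block = ceil(num_rows / world_size)  # rows per shard (even split assumed)
    routed = defaultdict(list)
    for idx in ids:
        if not 0 <= idx < num_rows:
            raise IndexError(f"id {idx} out of range for {num_rows} rows")
        rank = idx // block
        routed[rank].append(idx - rank * block)  # offset into that rank's shard
    return dict(routed)


print(route_row_wise([0, 5, 2047, 2048, 4095], num_rows=4096, world_size=2))
# {0: [0, 5, 2047], 1: [0, 2047]}
```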


In [11]:
spmd_sharing_simulation(ShardingType.ROW_WISE)

start for rank: 0
start for rank: 1
rank:0,sharding plan: module: 

    param     | sharding type | compute kernel | ranks 
------------- | ------------- | -------------- | ------
large_table_0 | row_wise      | fused          | [0, 1]
large_table_1 | row_wise      | fused          | [0, 1]
small_table_0 | row_wise      | fused          | [0, 1]
small_table_1 | row_wise      | fused          | [0, 1]

    param     | shard offsets | shard sizes |   placement  
------------- | ------------- | ----------- | -------------
large_table_0 | [0, 0]        | [2048, 64]  | rank:0/cuda:0
large_table_0 | [2048, 0]     | [2048, 64]  | rank:1/cuda:1
large_table_1 | [0, 0]        | [2048, 64]  | rank:0/cuda:0
large_table_1 | [2048, 0]     | [2048, 64]  | rank:1/cuda:1
small_table_0 | [0, 0]        | [512, 64]   | rank:0/cuda:0
small_table_0 | [512, 0]      | [512, 64]   | rank:1/cuda:1
small_table_1 | [0, 0]        | [512, 64]   | rank:0/cuda:0
small_table_1 | [512, 0]      | [512, 64]   | rank:1/cu



exit code: 0
exit code: 0


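Column-wise sharding, on the other hand, addresses load imbalance for tables with large embedding dimensions. It splits a table vertically, along the embedding dimension.

In this run, however, the `shard sizes` column of the printed plan still shows full `[4096, 64]` and `[1024, 64]` shards. With only 64 embedding dimensions, the planner kept each table as a single column-wise shard on one rank, so the placement matches table-wise. Tables with wider embeddings would typically be split across the two GPUs by embedding dimension.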
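Column-wise shards can be computed independently because sum pooling acts on each embedding dimension separately. Each rank pools its own columns, and the partial results are concatenated along the embedding dimension.

The toy sketch below checks this with a hypothetical 4 x 6 table split across two simulated ranks. `sum_pool` and `column_shards` are illustrative helpers.

```python
def sum_pool(table, ids):
    """Sum-pool the rows selected by ids (pooling=SUM, as in the EBC configs)."""
    dim = len(table[0])
    return [sum(table[i][d] for i in ids) for d in range(dim)]


def column_shards(table, world_size):
    """Split every row into contiguous column blocks (even split assumed)."""
    dim = len(table[0])
    block = -(-dim // world_size)  # ceiling division
    return [[row[start:start + block] for row in table] for start in range(0, dim, block)]


# Toy table: 4 rows, 6 dims; value 10 * row + col makes results easy to check.
table = [[10 * r + c for c in range(6)] for r in range(4)]
ids = [0, 2, 3]

full = sum_pool(table, ids)
# Each simulated rank pools over its own columns; concatenating the partial
# results along the embedding dimension reconstructs the full pooled embedding.
partials = [sum_pool(shard, ids) for shard in column_shards(table, world_size=2)]
stitched = [x for part in partials for x in part]
assert stitched == full
print(full)  # [50, 53, 56, 59, 62, 65]
```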


In [12]:
spmd_sharing_simulation(ShardingType.COLUMN_WISE)

start for rank: 0
start for rank: 1
rank:0,sharding plan: module: 

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
large_table_0 | column_wise   | fused          | [0]  
large_table_1 | column_wise   | fused          | [1]  
small_table_0 | column_wise   | fused          | [0]  
small_table_1 | column_wise   | fused          | [1]  

    param     | shard offsets | shard sizes |   placement  
------------- | ------------- | ----------- | -------------
large_table_0 | [0, 0]        | [4096, 64]  | rank:0/cuda:0
large_table_1 | [0, 0]        | [4096, 64]  | rank:1/cuda:1
small_table_0 | [0, 0]        | [1024, 64]  | rank:0/cuda:0
small_table_1 | [0, 0]        | [1024, 64]  | rank:1/cuda:1
rank:1,sharding plan: module: 

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
large_table_0 | column_wise   | fused          | [0]  
large_table_1 | column_wise   | fused       



exit code: 0
exit code: 0


Unfortunately, we cannot simulate `table-row-wise` sharding here, because it is designed to operate in a multi-host setup. We will present a Python [`SPMD`](https://en.wikipedia.org/wiki/SPMD) example in the future to train models with `table-row-wise`.


With data parallel, every table is replicated on all devices.
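Every rank holds a full copy of a data-parallel table. The replicas therefore stay identical only if each update uses gradients averaged across ranks, which in practice is done with an all-reduce collective.

The plain-Python sketch below simulates that averaging for two ranks. `allreduce_mean` and `sgd_step` are hypothetical helpers, and plain SGD is an assumed optimizer.

```python
def allreduce_mean(per_rank_grads):
    """Average gradient vectors elementwise across ranks (all-reduce sum, then divide)."""
    world_size = len(per_rank_grads)
    return [sum(vals) / world_size for vals in zip(*per_rank_grads)]


def sgd_step(weights, grad, lr):
    """Plain SGD update (assumed optimizer for this illustration)."""
    return [w - lr * g for w, g in zip(weights, grad)]


# Two replicas start from identical weights but see different mini-batches.
weights = [1.0, 1.0, 1.0]
grads = [[0.5, 0.0, -1.0],  # rank 0's local gradient
         [1.5, 2.0, 1.0]]   # rank 1's local gradient
avg = allreduce_mean(grads)
print(avg)  # [1.0, 1.0, 0.0]
# Every replica applies the same averaged update, so the copies stay in sync.
replicas = [sgd_step(weights, avg, lr=0.1) for _ in range(2)]
assert replicas[0] == replicas[1]
```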


In [13]:
spmd_sharing_simulation(ShardingType.DATA_PARALLEL)

start for rank: 0
start for rank: 1


Sharding Type is data_parallel, caching params will be ignored
Sharding Type is data_parallel, caching params will be ignored
Sharding Type is data_parallel, caching params will be ignored
Sharding Type is data_parallel, caching params will be ignored
Sharding Type is data_parallel, caching params will be ignored
Sharding Type is data_parallel, caching params will be ignored
Sharding Type is data_parallel, caching params will be ignored
Sharding Type is data_parallel, caching params will be ignored


rank:0,sharding plan: module: 

    param     | sharding type | compute kernel | ranks 
------------- | ------------- | -------------- | ------
large_table_0 | data_parallel | dense          | [0, 1]
large_table_1 | data_parallel | dense          | [0, 1]
small_table_0 | data_parallel | dense          | [0, 1]
small_table_1 | data_parallel | dense          | [0, 1]



rank:1,sharding plan: module: 

    param     | sharding type | compute kernel | ranks 
------------- | ------------- | -------------- | ------
large_table_0 | data_parallel | dense          | [0, 1]
large_table_1 | data_parallel | dense          | [0, 1]
small_table_0 | data_parallel | dense          | [0, 1]
small_table_1 | data_parallel | dense          | [0, 1]







exit code: 0
exit code: 0
