## **Installation**
Requirements:
- python >= 3.7

We highly recommend CUDA when using TorchRec. If using CUDA:
- cuda >= 11.0
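
Before installing anything, it can help to confirm that the interpreter meets these minimums. The sketch below is an illustration only: `meets_minimum` and `parse_version` are hypothetical helpers, not part of TorchRec, and the CUDA check is skipped when PyTorch is not installed yet.

```python
import sys

def meets_minimum(found, required):
    """Return True if version tuple `found` is at least `required`."""
    return tuple(found) >= tuple(required)

def parse_version(text):
    """Turn a string such as '11.3' into a tuple of ints, e.g. (11, 3)."""
    return tuple(int(part) for part in text.split(".") if part.isdigit())

print("python ok:", meets_minimum(sys.version_info[:2], (3, 7)))

try:
    import torch  # may not be installed yet at this point in the notebook
except ImportError:
    print("torch not installed yet; skipping the CUDA check")
else:
    cuda_version = torch.version.cuda  # None for CPU-only builds
    if cuda_version is None:
        print("CPU-only PyTorch build")
    else:
        print("cuda ok:", meets_minimum(parse_version(cuda_version), (11, 0)))
```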


In [None]:
# install conda to make installing pytorch with cudatoolkit 11.3 easier.
!sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*
!sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local

rm: cannot remove 'Miniconda3-py37_4.9.2-Linux-x86_64.sh.*': No such file or directory
--2022-04-18 20:38:22--  https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
Resolving repo.anaconda.com (repo.anaconda.com)... 104.16.130.3, 104.16.131.3, 2606:4700::6810:8303, ...
Connecting to repo.anaconda.com (repo.anaconda.com)|104.16.130.3|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 90040905 (86M) [application/x-sh]
Saving to: ‘Miniconda3-py37_4.9.2-Linux-x86_64.sh’


2022-04-18 20:38:23 (88.1 MB/s) - ‘Miniconda3-py37_4.9.2-Linux-x86_64.sh’ saved [90040905/90040905]

PREFIX=/usr/local
Unpacking payload ...
Collecting package metadata (current_repodata.json): - \ done
Solving environment: / - \ | / - \ failed with initial frozen solve. Retrying with flexible solve.
Solving environment: / - \ | failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.j

In [None]:
# install pytorch with cudatoolkit 11.3
!sudo conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y

Collecting package metadata (current_repodata.json): - \ | / - \ | / done
Solving environment: \ | / - \ | / done

# All requested packages already installed.



Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU-enabled operations to run 

In [None]:
# install torchrec
!pip3 install torchrec-nightly

Defaulting to user installation because normal site-packages is not writeable


In [None]:
!pip3 install multiprocess

Defaulting to user installation because normal site-packages is not writeable
Collecting multiprocess
  Downloading multiprocess-0.70.12.2-py37-none-any.whl (112 kB)
Collecting dill>=0.3.4
  Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)
Installing collected packages: dill, multiprocess
Successfully installed dill-0.3.4 multiprocess-0.70.12.2


The following step is needed for the Colab runtime to detect the added shared libraries. The runtime searches for shared libraries in /usr/lib, so we copy over the libraries that were installed in /usr/local/lib/. **This step is required, but only in the Colab runtime.**

In [None]:
!sudo cp /usr/local/lib/lib* /usr/lib/

**Restart your runtime at this point so that the newly installed packages are visible.** Run the step below immediately after restarting so that Python knows where to look for packages. **Always run this step after restarting the runtime.**

In [1]:
import sys
sys.path = ['', '/env/python', '/usr/local/lib/python37.zip', '/usr/local/lib/python3.7', '/usr/local/lib/python3.7/lib-dynload', '/usr/local/lib/python3.7/site-packages', './.local/lib/python3.7/site-packages']

## **Overview**
This tutorial covers the sharding schemes for embedding tables, using the `EmbeddingShardingPlanner` and `DistributedModelParallel` APIs. We configure how the tables should be sharded and explore different schemes for model-parallel placement of embedding tables across multiple GPUs.

### Distributed Setup
Due to the notebook environment, we cannot run an SPMD program here, but we can use multiprocessing inside the notebook to mimic the setup. Users are responsible for setting up their own SPMD launcher when using TorchRec.
We set up our environment so that the torch.distributed communication backend can work.

In [2]:
import os
import torch
import torchrec

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
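
The port `29500` is hard-coded above, so a repeated run can fail if something on the machine still holds that port. As a hedged alternative, the sketch below asks the operating system for a free port before setting the same two variables. `find_free_port` is a hypothetical helper, not a TorchRec or PyTorch API, and there is a small window in which another process could claim the port after it is released.

```python
import os
import socket

def find_free_port(host="localhost"):
    """Ask the OS for an unused TCP port by binding to port 0."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(find_free_port())
print(os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"])
```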

### Constructing our embedding model
Here we use TorchRec's [`EmbeddingBagCollection`](https://github.com/facebookresearch/torchrec/blob/main/torchrec/modules/embedding_modules.py#L59) to construct our embedding bag model with embedding tables.

Here, we create an EmbeddingBagCollection (EBC) with four embedding bags. We have two types of tables, large and small, differentiated by their number of rows: 4096 vs. 1024. Every table uses a 64-dimensional embedding.
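
To put the table sizes in perspective, the arithmetic below estimates the parameter count and memory footprint of the four tables. It assumes 4-byte `float32` weights, which the text above does not state, and it ignores optimizer state.

```python
BYTES_PER_PARAM = 4  # assumption: float32 weights

tables = {
    "large_table_0": (4096, 64),
    "large_table_1": (4096, 64),
    "small_table_0": (1024, 64),
    "small_table_1": (1024, 64),
}

total_params = 0
for name, (rows, dim) in tables.items():
    params = rows * dim
    total_params += params
    print(f"{name}: {params} params, {params * BYTES_PER_PARAM / 2**20:.2f} MiB")

total_mib = total_params * BYTES_PER_PARAM / 2**20
print(f"total: {total_params} params, {total_mib:.2f} MiB")
```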

We configure the `ParameterConstraints` data structure for the tables, which provides hints for the model parallel API to help decide the sharding and placement strategy for the tables.
In TorchRec, we support the following sharding types:
* `table-wise`: place the entire table on one device;
* `row-wise`: shard the table evenly by row dimension and place one shard on each device of the communication world;
* `table-row-wise`: a special sharding optimized for intra-host communication when a fast intra-machine device interconnect, e.g. NVLink, is available;
* `column-wise`: shard the table evenly by embedding dimension, and place one shard on each device of the communication world.
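
To make the shard geometry concrete, the sketch below computes the offsets and sizes that an even row-wise or column-wise split would produce. It is plain Python for illustration, not the TorchRec planner: `split_evenly`, `row_wise_shards`, and `column_wise_shards` are hypothetical helpers, and the way a remainder is spread across ranks is an arbitrary choice made here.

```python
def split_evenly(length, world_size):
    """Split `length` into `world_size` contiguous (offset, size) chunks."""
    base, extra = divmod(length, world_size)
    chunks, offset = [], 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)  # spread any remainder
        chunks.append((offset, size))
        offset += size
    return chunks

def row_wise_shards(rows, dim, world_size):
    """One shard per rank, each holding a contiguous block of rows."""
    return [{"rank": r, "offsets": [off, 0], "sizes": [size, dim]}
            for r, (off, size) in enumerate(split_evenly(rows, world_size))]

def column_wise_shards(rows, dim, world_size):
    """One shard per rank, each holding a contiguous block of columns."""
    return [{"rank": r, "offsets": [0, off], "sizes": [rows, size]}
            for r, (off, size) in enumerate(split_evenly(dim, world_size))]

for shard in row_wise_shards(4096, 64, world_size=2):
    print("row-wise   ", shard)
for shard in column_wise_shards(4096, 64, world_size=2):
    print("column-wise", shard)
```

For the 4096 x 64 table on two ranks, the row-wise result matches the `shard_offsets` and `shard_sizes` in the row-wise plan printed later in this tutorial.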

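Note the `device` argument passed to the EBC. Allocating on device "meta" tells the EBC not to allocate memory yet, whereas the cell that constructs the EBC below passes "cuda", which allocates the tables on the GPU right away.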

In [10]:
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict

large_table_cnt = 2
small_table_cnt = 2
large_tables=[
  torchrec.EmbeddingBagConfig(
    name="large_table_" + str(i),
    embedding_dim=64,
    num_embeddings=4096,
    feature_names=["large_table_feature_" + str(i)],
    pooling=torchrec.PoolingType.SUM,
  ) for i in range(large_table_cnt)
]
small_tables=[
  torchrec.EmbeddingBagConfig(
    name="small_table_" + str(i),
    embedding_dim=64,
    num_embeddings=1024,
    feature_names=["small_table_feature_" + str(i)],
    pooling=torchrec.PoolingType.SUM,
  ) for i in range(small_table_cnt)
]

def gen_constraints(sharding_type: ShardingType = ShardingType.TABLE_WISE) -> Dict[str, ParameterConstraints]:
  large_table_constraints = {
    "large_table_" + str(i): ParameterConstraints(
      sharding_types=[sharding_type.value],
      compute_kernels=[EmbeddingComputeKernel.BATCHED_FUSED.value],
    ) for i in range(large_table_cnt)
  }
  small_table_constraints = {
    "small_table_" + str(i): ParameterConstraints(
      sharding_types=[sharding_type.value],
      compute_kernels=[EmbeddingComputeKernel.BATCHED_FUSED.value],
    ) for i in range(small_table_cnt)
  }
  constraints = {**large_table_constraints, **small_table_constraints}
  return constraints

In [11]:
ebc = torchrec.EmbeddingBagCollection(
    device="cuda",
    tables=large_tables + small_tables
)

### DistributedModelParallel in multiprocessing
Now we define a single-process execution function that mimics one rank's work during SPMD execution.

This code shards the model collectively with the other processes and allocates memory accordingly. It first sets up the process group, then uses the planner to place the embedding tables, and finally generates the sharded model using `DistributedModelParallel`.


In [23]:
import torch.nn as nn
def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: nn.Module,
    backend: str,
) -> None:
  import os
  import torch
  import torch.nn as nn
  import torch.distributed as dist
  from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
  from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
  from torchrec.distributed.model_parallel import DistributedModelParallel
  from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
  from typing import cast

  def init_distributed_single_host(
      rank: int,
      world_size: int,
      backend: str,
      # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
  ) -> dist.ProcessGroup:
      os.environ["RANK"] = f"{rank}"
      os.environ["WORLD_SIZE"] = f"{world_size}"
      dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
      return dist.group.WORLD
  if backend == "nccl":
      device = torch.device(f"cuda:{rank}")
      torch.cuda.set_device(device)
  else:
      device = torch.device("cpu")
  topology = Topology(world_size=world_size, compute_device="cuda")
  pg = init_distributed_single_host(rank, world_size, backend)
  planner = EmbeddingShardingPlanner(
    topology=topology,
    constraints=constraints,
  )
  sharders= [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
  plan: ShardingPlan = planner.collective_plan(module, sharders, pg)

  sharded_model = DistributedModelParallel(
      module,
      env=ShardingEnv.from_process_group(pg),
      plan=plan,
      sharders=sharders,
      device=device,
  )
  print(f"rank:{rank},sharding plan: {plan}")
  return sharded_model


### Multiprocessing Execution
Now let's execute the code in multiple processes, each representing one GPU rank.



In [24]:
import multiprocess
   
def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size = 2,
):
  ctx = multiprocess.get_context("forkserver")
  processes = []
  for rank in range(world_size):
      p = ctx.Process(
          target=single_rank_execution,
          args=(
              rank,
              world_size,
              gen_constraints(sharding_type),
              ebc,
              "nccl"
          ),
      )
      p.start()
      processes.append(p)

  for p in processes:
      p.join()
      assert 0 == p.exitcode
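
The launch pattern above (start one process per rank, join them, check the exit codes) can be tried without GPUs or TorchRec. The sketch below uses the standard-library `multiprocessing` module instead of `multiprocess`, and the `fork` start method instead of `forkserver`, purely to keep the example self-contained on POSIX systems. `toy_rank` and `launch_ranks` are hypothetical names.

```python
import multiprocessing

def toy_rank(rank, world_size, queue):
    """Stand-in for one rank's work: report which rank ran."""
    queue.put((rank, world_size))

def launch_ranks(world_size=2):
    """Start one process per rank, wait for all, and return their results."""
    ctx = multiprocessing.get_context("fork")  # assumption: POSIX host
    queue = ctx.Queue()
    processes = [
        ctx.Process(target=toy_rank, args=(rank, world_size, queue))
        for rank in range(world_size)
    ]
    for p in processes:
        p.start()
    # Collect one result per rank before joining the processes.
    results = sorted(queue.get(timeout=30) for _ in processes)
    for p in processes:
        p.join()
        assert p.exitcode == 0, f"rank process exited with {p.exitcode}"
    return results

print(launch_ranks(world_size=2))
```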

### Table Wise Sharding
Now let's execute the code in two processes for 2 GPUs. The printed plan shows how our tables are sharded across GPUs. Each rank holds one large table and one small table, which shows that the planner tries to balance the load of the embedding tables.

In [25]:
spmd_sharing_simulation(ShardingType.TABLE_WISE)

rank:0,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])), 'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
rank:1,shar
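
One way to see why this placement is balanced is a greedy heuristic: take tables from largest to smallest and give each one to the rank that currently holds the fewest rows. The sketch below only illustrates that idea and is not how `EmbeddingShardingPlanner` is implemented; `greedy_table_wise` is a hypothetical helper. For these four tables it happens to produce the same assignment as the plan printed above.

```python
def greedy_table_wise(table_rows, world_size):
    """Assign whole tables to ranks, largest first, onto the lightest rank."""
    load = [0] * world_size  # rows currently placed on each rank
    placement = {}
    # sorted() is stable, so equally sized tables keep their given order.
    for name, rows in sorted(table_rows.items(), key=lambda kv: -kv[1]):
        rank = load.index(min(load))  # lightest rank, lowest index on ties
        placement[name] = rank
        load[rank] += rows
    return placement, load

tables = {
    "large_table_0": 4096,
    "large_table_1": 4096,
    "small_table_0": 1024,
    "small_table_1": 1024,
}
placement, load = greedy_table_wise(tables, world_size=2)
print(placement)
print(load)
```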



### Explore other sharding modes
We have explored what table-wise sharding looks like and how it balances table placement. Now we explore a sharding mode with a finer-grained focus on load balance: row-wise.


In [26]:
spmd_sharing_simulation(ShardingType.ROW_WISE)

rank:1,sharding plan: {'': {'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])), 'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank

