Reference: https://docs.pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html

## Embeddings Recap

In [1]:
import torch

In [2]:
torch.cuda.is_available()

True

In [3]:
num_embeddings, embedding_dim = 10, 4

# Initialize embedding table
weights = torch.rand(num_embeddings, embedding_dim)
print("Weights: ", weights)

Weights:  tensor([[0.7434, 0.9344, 0.5137, 0.6649],
        [0.6238, 0.9152, 0.2048, 0.6827],
        [0.8833, 0.7587, 0.9536, 0.2427],
        [0.3581, 0.8361, 0.2727, 0.3238],
        [0.1709, 0.1559, 0.4461, 0.7244],
        [0.6119, 0.2545, 0.9343, 0.6560],
        [0.2828, 0.0089, 0.7649, 0.7542],
        [0.2769, 0.1665, 0.6983, 0.4497],
        [0.2426, 0.0923, 0.6919, 0.9160],
        [0.9665, 0.7026, 0.1881, 0.0379]])


In [4]:
# Pass in pre-generated weights for demonstration purposes
embedding_collection = torch.nn.Embedding(
    num_embeddings, embedding_dim, _weight=weights  # weights are usually random init
)

# takes the mean of index embeddings passed
embedding_bag_collection = torch.nn.EmbeddingBag(
    num_embeddings, embedding_dim, _weight=weights
)

print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

Embedding Collection Table:  Parameter containing:
tensor([[0.7434, 0.9344, 0.5137, 0.6649],
        [0.6238, 0.9152, 0.2048, 0.6827],
        [0.8833, 0.7587, 0.9536, 0.2427],
        [0.3581, 0.8361, 0.2727, 0.3238],
        [0.1709, 0.1559, 0.4461, 0.7244],
        [0.6119, 0.2545, 0.9343, 0.6560],
        [0.2828, 0.0089, 0.7649, 0.7542],
        [0.2769, 0.1665, 0.6983, 0.4497],
        [0.2426, 0.0923, 0.6919, 0.9160],
        [0.9665, 0.7026, 0.1881, 0.0379]], requires_grad=True)
Embedding Bag Collection Table:  Parameter containing:
tensor([[0.7434, 0.9344, 0.5137, 0.6649],
        [0.6238, 0.9152, 0.2048, 0.6827],
        [0.8833, 0.7587, 0.9536, 0.2427],
        [0.3581, 0.8361, 0.2727, 0.3238],
        [0.1709, 0.1559, 0.4461, 0.7244],
        [0.6119, 0.2545, 0.9343, 0.6560],
        [0.2828, 0.0089, 0.7649, 0.7542],
        [0.2769, 0.1665, 0.6983, 0.4497],
        [0.2426, 0.0923, 0.6919, 0.9160],
        [0.9665, 0.7026, 0.1881, 0.0379]], requires_grad=True)


In [5]:
# Lookup rows from the embedding tables
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)
embeddings = embedding_collection(ids)
# Print out the embedding lookups
print("Embedding Collection Results: ")
print(embeddings)
print("Shape: ", embeddings.shape)

Input row IDS:  tensor([[1, 3]])
Embedding Collection Results: 
tensor([[[0.6238, 0.9152, 0.2048, 0.6827],
         [0.3581, 0.8361, 0.2727, 0.3238]]], grad_fn=<EmbeddingBackward0>)
Shape:  torch.Size([1, 2, 4])


In [6]:
# nn.EmbeddingBag takes the mean across batch dimension of above result
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ")
print(pooled_embeddings)
print("Shape: ", pooled_embeddings.shape)

Embedding Bag Collection Results: 
tensor([[0.4910, 0.8757, 0.2387, 0.5032]], grad_fn=<EmbeddingBagBackward0>)
Shape:  torch.Size([1, 4])


In [7]:
# Same as
print("Mean: ", torch.mean(embedding_collection(ids), dim=1))

Mean:  tensor([[0.4910, 0.8757, 0.2387, 0.5032]], grad_fn=<MeanBackward1>)


## TorchRec features

In [8]:
import torchrec

`EmbeddingBagCollection`: represents a group of embedding bags

Example: Create an `EmbeddingBagCollection` (EBC) with two embedding bags, 1 representing products and 1 representing users.

In [9]:
ebc = torchrec.EmbeddingBagCollection(
    device="cpu",
    tables=[
        torchrec.EmbeddingBagConfig(
            name="product_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["product"],
            pooling=torchrec.PoolingType.SUM,
        ),
        torchrec.EmbeddingBagConfig(
            name="user_table",
            embedding_dim=64,
            num_embeddings=4096,
            feature_names=["user"],
            pooling=torchrec.PoolingType.SUM,
        ),
    ],
)
print(ebc.embedding_bags)

ModuleDict(
  (product_table): EmbeddingBag(4096, 64, mode='sum')
  (user_table): EmbeddingBag(4096, 64, mode='sum')
)


Inspect forward method

In [10]:
import inspect
print(inspect.getsource(ebc.forward))

    def forward(
        self,
        features: KeyedJaggedTensor,  # can also take TensorDict as input
    ) -> KeyedTensor:
        """
        Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
        and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.

        Args:
            features (KeyedJaggedTensor): Input KJT
        Returns:
            KeyedTensor
        """
        flat_feature_names: List[str] = []
        features = maybe_td_to_kjt(features, None)
        for names in self._feature_names:
            flat_feature_names.extend(names)
        inverse_indices = reorder_inverse_indices(
            inverse_indices=features.inverse_indices_or_none(),
            feature_names=flat_feature_names,
        )
        pooled_embeddings: List[torch.Tensor] = []
        feature_dict = features.to_dict()
        for i, embedding_bag in enumerate(self.embedding_bags.values()):
            for feature_na

### Input / Output Data Types

In [11]:
# Batch Size 2
# 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# Values (IDs) tensor: ID 5 is in example 1, ID 7, 1 is in example 2
id_list_feature_values = torch.tensor([5, 7, 1])
# Lengths can be converted to offsets for easy indexing
id_list_feature_offsets = torch.cumsum(id_list_feature_lengths, dim=0)

In [12]:
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[: id_list_feature_offsets[0]])
print("Second Batch: ", id_list_feature_values[id_list_feature_offsets[0] : id_list_feature_offsets[1]])

Offsets:  tensor([1, 3])
First Batch:  tensor([5])
Second Batch:  tensor([7, 1])


`torchrec` abstraction: `JaggedTensor`

In [13]:
from torchrec import JaggedTensor

jt = JaggedTensor(values=id_list_feature_values, lengths=id_list_feature_lengths)
print("Offsets: ", jt.offsets())
print("List of values: ", jt.to_dense())
print(jt)

Offsets:  tensor([0, 1, 3])
List of values:  [tensor([5]), tensor([7, 1])]
JaggedTensor({
    [[5], [7, 1]]
})



In [14]:
from torchrec import KeyedJaggedTensor
# ``JaggedTensor`` represents IDs for 1 feature, but we have multiple features in an ``EmbeddingBagCollection``
# That's where ``KeyedJaggedTensor`` comes in! ``KeyedJaggedTensor`` is just multiple ``JaggedTensors`` for multiple id_list_feature_offsets
# From before, we have our two features "product" and "user". Let's create ``JaggedTensors`` for both!
product_jt = JaggedTensor(
    values=torch.tensor([1, 2, 3, 1, 5, 7, 9]),
    lengths=torch.tensor([3, 4]),
)
user_jt = JaggedTensor(
    values=torch.tensor([2, 3, 4, 1]),
    lengths=torch.tensor([2, 2]),
)
kjt = KeyedJaggedTensor.from_jt_dict({"product": product_jt, "user": user_jt})

# It basically is a batch size of 2
# Product 1 has 3 IDs, Product 2 has 4 IDs
# User 1 has 2 IDs, User 2 also has 2 IDs
print("Keys: ", kjt.keys())
print("Lengths: ", kjt.lengths())
print("Values: ", kjt.values())
print("to_dict: ", kjt.to_dict())
print(kjt)

Keys:  ['product', 'user']
Lengths:  tensor([3, 4, 2, 2])
Values:  tensor([1, 2, 3, 1, 5, 7, 9, 2, 3, 4, 1])
to_dict:  {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x73ded3ca0f50>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x73ded3ca2ad0>}
KeyedJaggedTensor({
    "product": [[1, 2, 3], [1, 5, 7, 9]],
    "user": [[2, 3], [4, 1]]
})



In [15]:
# Forward pass on the `EmbeddingCollectionBag`
result = ebc(kjt)
result

<torchrec.sparse.jagged_tensor.KeyedTensor at 0x73ded379b390>

In [16]:
print(result.keys())
print(result.values().shape)  # Remember the embedding bag pools the multiple embeddings into one (mean)
result_dict = result.to_dict()
for key, embedding in result_dict.items():
    print(key, embedding.shape)

['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])


## Distributed Training and Sharding

Idea: distribute embeddings

In [17]:
import os
import torch.distributed as dist
# RANK: GPU device, default 0
os.environ["RANK"] = "0"
# How many devices in our "world"
os.environ["WORLD_SIZE"] = "1"
# Localhost (local training)
os.environ["MASTER_ADDR"] = "localhost"
# Port for distributed training
os.environ["MASTER_PORT"] = "29500"

dist.init_process_group(backend="gloo")
print(f"Distributed environment initialized: {dist}")

Distributed environment initialized: <module 'torch.distributed' from '/home/sadelcarpio/code/steam_games_recommender/notebooks/.venv/lib/python3.11/site-packages/torch/distributed/__init__.py'>


**Model parallel**: Distribute embedding tables across multiple GPUs. Splitting up the embeddings $\rightarrow$ **sharding**

Types of sharding: Table-Wise, Column-Wise, Row-Wise

Each TorchRec module has an unsharder and sharder variant. Unsharded for prototyping, sharded for production.

**Data parallel**: Replicate entire model on each GPU, and each GPU takes in a distinct batch of data (gradients are synced on the backward pass)

In [18]:
ebc

EmbeddingBagCollection(
  (embedding_bags): ModuleDict(
    (product_table): EmbeddingBag(4096, 64, mode='sum')
    (user_table): EmbeddingBag(4096, 64, mode='sum')
  )
)

In [19]:
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.types import ShardingEnv

sharder = EmbeddingBagCollectionSharder()
# `ProcessGroup` from torch.distributed
pg = dist.GroupMember.WORLD
assert pg is not None, "ProcessGroup not initialized"

print(f"Process group: {pg}")

Process group: <torch.distributed.distributed_c10d.ProcessGroup object at 0x73df82774770>


**Planner**: Helps determine the best sharding configuration
Involves:
- Assessing memory constraints of hardware
- Estimate compute based on memory fetches (embedding lookups)
- Addresses data specific factors
- Considers other hardware specifics like bandwidth

In [20]:
# 1 GPU anc compute on CUDA device
planner = EmbeddingShardingPlanner(
    topology=Topology(
        world_size=1,
        compute_device="cuda",
    )
)

plan = planner.collective_plan(ebc, [sharder], pg)
print(f"Sharding Plan generated: {plan}")

Sharding Plan generated: module: 

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
product_table | table_wise    | fused          | [0]  
user_table    | table_wise    | fused          | [0]  

    param     | shard offsets | shard sizes |   placement  
------------- | ------------- | ----------- | -------------
product_table | [0, 0]        | [4096, 64]  | rank:0/cuda:0
user_table    | [0, 0]        | [4096, 64]  | rank:0/cuda:0


In [23]:
dir(pg)

['BackendType',
 'CUSTOM',
 'GLOO',
 'MPI',
 'NCCL',
 'UCC',
 'UNDEFINED',
 'XCCL',
 '__class__',
 '__delattr__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '_allgather_base',
 '_backend_id',
 '_device_types',
 '_enable_collectives_timing',
 '_end_coalescing',
 '_get_backend',
 '_get_backend_name',
 '_get_sequence_number_for_group',
 '_has_hooks',
 '_id',
 '_pybind11_conduit_v1_',
 '_reduce_scatter_base',
 '_register_backend',
 '_register_on_completion_hook',
 '_set_default_backend',
 '_set_group_desc',
 '_set_group_name',
 '_set_sequence_number_for_group',
 '_start_coalescing',
 '_wait_for_pending_works',
 'abort',
 'allgather',
 'allgather_coalesced',
 'allgather_into_tensor_coalesced',
 'allreduce',
 '

In [21]:
# After creating the static plan, shard and generate a ShardedEmbeddingBagCollection
env = ShardingEnv.from_process_group(pg)
sharded_ebc = sharder.shard(ebc, plan.plan[""], env, torch.device("cuda"))

print(f"Sharded EBC: {sharded_ebc}")

Sharded EBC: ShardedEmbeddingBagCollection(
  (lookups): 
   GroupedPooledEmbeddingsLookup(
      (_emb_modules): ModuleList(
        (0): BatchedFusedEmbeddingBag(
          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
        )
      )
    )
   (_output_dists): 
   TwPooledEmbeddingDist()
  (embedding_bags): ModuleDict(
    (product_table): Module()
    (user_table): Module()
  )
)




### GPU Training with LazyAwaitable
`LazyAwaitable` type delays calculating some results as long as possible

In [28]:
from typing import List
from torchrec.distributed.types import LazyAwaitable

class ExampleAwaitable(LazyAwaitable[List[torch.Tensor]]):
    def __init__(self, size: List[int]) -> None:
        super().__init__()
        self._size = size
    def _wait_impl(self) -> torch.Tensor:
        return torch.ones(self._size)

In [29]:
awaitable = ExampleAwaitable([3, 2])
awaitable.wait()

tensor([[1., 1.],
        [1., 1.],
        [1., 1.]])

In [30]:
kjt = kjt.to("cuda")
output = sharded_ebc(kjt)
print(output)

<torchrec.distributed.embeddingbag.EmbeddingBagCollectionAwaitable object at 0x73dec4d0c810>


In [34]:
kt = output.wait()
print(type(kt))
print(kt.keys())
print(kt.values().shape)
result_dict = kt.to_dict()
# Same output as unsharded `EmbeddingBagCollection`
for key, embedding in result_dict.items():
    print(key, embedding.shape)

<class 'torchrec.sparse.jagged_tensor.KeyedTensor'>
['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])


Common APIs for distributed training and inference:
* `input_dist`: Handles distributing inputs from GPU to GPU
* `lookups`: Does the actual embedding lookup in an optimized, batched manner using FBGEMM TBE (?)
* `output_dist`: Handles distributing outputs from GPU to GPU

In [35]:
sharded_ebc

ShardedEmbeddingBagCollection(
  (lookups): 
   GroupedPooledEmbeddingsLookup(
      (_emb_modules): ModuleList(
        (0): BatchedFusedEmbeddingBag(
          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
        )
      )
    )
   (_input_dists): 
   TwSparseFeaturesDist(
      (_dist): KJTAllToAll()
    )
   (_output_dists): 
   TwPooledEmbeddingDist(
      (_dist): PooledEmbeddingsAllToAll()
    )
  (embedding_bags): ModuleDict(
    (product_table): Module()
    (user_table): Module()
  )
)

In [36]:
# Distribute input KJTs to all other GPUs and receive KJTs
sharded_ebc._input_dists

[TwSparseFeaturesDist(
   (_dist): KJTAllToAll()
 )]

In [37]:
# Distribute output embeddings to all other GPUs and receive embeddings
sharded_ebc._output_dists

[TwPooledEmbeddingDist(
   (_dist): PooledEmbeddingsAllToAll()
 )]

**FBGEMM** : provides GPU operators (kernels) optimized for performing lookups of embedding tables.

In [38]:
sharded_ebc._lookups

[GroupedPooledEmbeddingsLookup(
   (_emb_modules): ModuleList(
     (0): BatchedFusedEmbeddingBag(
       (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
     )
   )
 )]

### DistributedModelParallel
DMP decides how to shard the model, and shards it. Abstracts all sharding theory above

In [39]:
model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device("cuda"))
out = model(kjt)
out.wait()



<torchrec.sparse.jagged_tensor.KeyedTensor at 0x73dea2054c10>

In [40]:
model

DistributedModelParallel(
  (_dmp_wrapped_module): ShardedEmbeddingBagCollection(
    (lookups): 
     GroupedPooledEmbeddingsLookup(
        (_emb_modules): ModuleList(
          (0): BatchedFusedEmbeddingBag(
            (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
          )
        )
      )
     (_input_dists): 
     TwSparseFeaturesDist(
        (_dist): KJTAllToAll()
      )
     (_output_dists): 
     TwPooledEmbeddingDist(
        (_dist): PooledEmbeddingsAllToAll()
      )
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
  )
)

### Adding an Optimizer to `EmbeddingBagCollection`

In [42]:
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from torchrec.optim.optimizers import in_backward_optimizer_filter

In [43]:
fused_params = {
    "optimizer": EmbOptimType.EXACT_ROWWISE_ADAGRAD,
    "learning_rate": 0.02,
    "eps": 0.002,
}
# Initialize sharder with `fused_params`
sharded_with_fused_params = EmbeddingBagCollectionSharder(fused_params=fused_params)
sharded_ebc_fused_params = sharded_with_fused_params.shard(
    ebc, plan.plan[""], env, torch.device("cuda")
)
print(f"Original Sharded EBC fused optimizer: {sharded_ebc_fused_params}")
print(f"Sharded EBC with fused parameters fused optimizer: {sharded_ebc_fused_params.fused_optimizer}")
print(f"Type of optimizer: {type(sharded_ebc_fused_params.fused_optimizer)}")

Original Sharded EBC fused optimizer: ShardedEmbeddingBagCollection(
  (lookups): 
   GroupedPooledEmbeddingsLookup(
      (_emb_modules): ModuleList(
        (0): BatchedFusedEmbeddingBag(
          (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
        )
      )
    )
   (_output_dists): 
   TwPooledEmbeddingDist()
  (embedding_bags): ModuleDict(
    (product_table): Module()
    (user_table): Module()
  )
)
Sharded EBC with fused parameters fused optimizer: : EmbeddingFusedOptimizer (
Parameter Group 0
    lr: 0.019999999552965164
)
Type of optimizer: <class 'torchrec.optim.keyed.CombinedOptimizer'>
