# DynamicEmb Quick Start

This notebook gives a quick, hands-on introduction to the DynamicEmb API. It walks through creating an HKV (HierarchicalKV) embedding table with DynamicEmb and then training it.

## **Installation**
Requirements:
- TorchREC == v0.7

DynamicEmb v0.1 depends on our customized build of TorchREC.

## **Overview**
This tutorial is a quick-start guide to using DynamicEmb in TorchREC. It covers creating an HKV embedding table and running a sequence embedding lookup with both forward and backward passes.

### Torch Setup
We set up the environment with `torch.distributed` and define the embedding configuration.

Here we use a single rank, which corresponds to one GPU.

**Bash commands (execute in a terminal):**
Before running this notebook, set the following environment variables in the Linux shell you launch Jupyter from, so that the notebook kernel inherits them:
```
export RANK=0
export WORLD_SIZE=1
export MASTER_ADDR=127.0.0.1
export MASTER_PORT=29500
export LOCAL_WORLD_SIZE=1
```
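`torch.distributed`'s default `env://` initialization reads `RANK`, `WORLD_SIZE`, `MASTER_ADDR`, and `MASTER_PORT`, and we assume TorchREC's `get_local_size()` (used by the planner below) reads `LOCAL_WORLD_SIZE`. A missing variable usually surfaces as a hang or an unclear error inside `init_process_group`, so a quick check can help. The helper `missing_dist_env` below is a hypothetical sketch, not part of DynamicEmb or TorchREC.

```python
import os

# Variables the setup cell below relies on (LOCAL_WORLD_SIZE is an assumption
# based on how TorchREC's get_local_size() discovers the per-node process count).
REQUIRED_ENV_VARS = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_WORLD_SIZE")


def missing_dist_env(environ=None):
    """Return the required variable names that are unset or empty."""
    environ = os.environ if environ is None else environ
    return [name for name in REQUIRED_ENV_VARS if not environ.get(name)]
```

In the notebook, `missing_dist_env()` could be called before `dist.init_process_group`, raising an error if it returns a non-empty list.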

```python
import warnings
# Filter FBGEMM's FutureWarning about torch.library.impl_abstract to keep the notebook output clean
warnings.filterwarnings("ignore", message=".*torch.library.impl_abstract.*", category=FutureWarning)
import os
import numpy as np
import torch
import torchrec
import torch.distributed as dist

backend = "nccl"
dist.init_process_group(backend=backend)

local_rank = dist.get_rank()  # the global rank equals the local rank only on a single node
world_size = dist.get_world_size()
torch.cuda.set_device(local_rank)
device = torch.device(f"cuda:{local_rank}")
np.random.seed(1024 + local_rank)

# Define the configuration parameters for the embedding table,
# including its name, embedding dimension, total number of embeddings, and feature name.
embedding_table_name = "table_0"
embedding_table_dim = 128
total_num_embedding = 1000
embedding_feature_name = "cate_0"
batch_size = 16
```
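`local_rank = dist.get_rank()` above is correct only on a single node. For multi-node runs, a minimal sketch of choosing the GPU index is shown below. It assumes the launcher either exports `LOCAL_RANK` (as `torchrun` does) or assigns ranks contiguously on each node; `resolve_local_rank` is a hypothetical helper, not a DynamicEmb or TorchREC function.

```python
import os


def resolve_local_rank(global_rank, environ=None):
    """Return the GPU index this process should use on its node.

    Prefers LOCAL_RANK when the launcher sets it; otherwise assumes ranks are
    packed contiguously per node, so local_rank = global_rank % LOCAL_WORLD_SIZE.
    """
    environ = os.environ if environ is None else environ
    if environ.get("LOCAL_RANK"):
        return int(environ["LOCAL_RANK"])
    local_world_size = int(environ.get("LOCAL_WORLD_SIZE", "1"))
    return global_rank % local_world_size
```

With this helper, `local_rank = resolve_local_rank(dist.get_rank())` would replace the direct `dist.get_rank()` call; on the single-GPU setup used here both give 0.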

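### Applying EmbeddingConfig and EmbeddingCollection in TorchREC
Embedding tables are defined exactly as in TorchREC, using its existing `EmbeddingConfig` and `EmbeddingCollection` APIs. The collection is created on the `meta` device, so no embedding memory is allocated until the model is sharded.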

```python
eb_configs = [
    torchrec.EmbeddingConfig(
        name=embedding_table_name,
        embedding_dim=embedding_table_dim,
        num_embeddings=total_num_embedding,
        feature_names=[embedding_feature_name],
    )
]

ebc = torchrec.EmbeddingCollection(
    device=torch.device("meta"),
    tables=eb_configs,
)
```

### Configuring the DynamicEmb customized planner

In DynamicEmb, the Customized Planner is the entry point for configuring and planning HKV embedding tables. DynamicEmb provides `DynamicEmbParameterConstraints`, `DynamicEmbeddingEnumerator`, and `DynamicEmbeddingShardingPlanner` for this purpose. Because they inherit from or wrap the corresponding TorchREC APIs, they stay functionally compatible with TorchREC and follow usage patterns that TorchREC users already know.

The following code is a simple example of building a planner for a single HKV table. For more details on using the DynamicEmb API, see the API documentation in `DynamicEmb_APIs.md`.

```python
import math
from typing import Dict

from torchrec.distributed.comm import get_local_size
from torchrec import DataType
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology, ParameterConstraints
from torchrec.distributed.types import (
    ModuleSharder,
    ShardingType,
)
from torchrec.distributed.planner.storage_reservations import (
    HeuristicalStorageReservation,
)
from torchrec.distributed.types import (
    BoundsCheckMode,
)

from dynamicemb.planner import DynamicEmbParameterConstraints, DynamicEmbeddingShardingPlanner
from dynamicemb.planner import DynamicEmbeddingEnumerator
from dynamicemb import DynamicEmbInitializerMode, DynamicEmbInitializerArgs, DynamicEmbTableOptions

# Wrap all of the planner code in one function
def get_planner(device, eb_configs, batch_size):

    DATA_TYPE_NUM_BITS: Dict[DataType, int] = {
        DataType.FP32: 32,
        DataType.FP16: 16,
        DataType.BF16: 16,
    }

    # For an HKV embedding table, calculate how many bytes of embedding vectors are stored in GPU HBM.
    # In this example, all embedding vectors are placed in GPU HBM.
    eb_config = eb_configs[0]
    dim = eb_config.embedding_dim
    tmp_type = eb_config.data_type

    embedding_type_bytes = DATA_TYPE_NUM_BITS[tmp_type] // 8  # integer number of bytes per element
    emb_num_embeddings = eb_config.num_embeddings
    # HKV requires the number of embedding vectors (the table capacity) to be a power of 2
    emb_num_embeddings_next_power_of_2 = 2 ** math.ceil(math.log2(emb_num_embeddings))
    total_hbm_need = embedding_type_bytes * dim * emb_num_embeddings_next_power_of_2

    hbm_cap = 80 * 1024 * 1024 * 1024  # HBM bytes per GPU (80 GB H100)
    ddr_cap = 512 * 1024 * 1024 * 1024  # assume each node has 512 GB of host memory
    intra_host_bw = 450e9  # NVLink bandwidth
    inter_host_bw = 25e9  # NIC bandwidth

    dict_const = {}

    const = DynamicEmbParameterConstraints(
        sharding_types=[
            ShardingType.ROW_WISE.value,
        ],
        enforce_hbm=True,
        bounds_check_mode=BoundsCheckMode.NONE,
        # The options from here on are HKV-specific. use_dynamicemb defaults to False,
        # in which case the constraints fall back to plain TorchREC ParameterConstraints.
        use_dynamicemb=True,
        dynamicemb_options=DynamicEmbTableOptions(
            global_hbm_for_values=total_hbm_need,
            initializer_args=DynamicEmbInitializerArgs(
                mode=DynamicEmbInitializerMode.NORMAL
            ),
        ),
    )

    dict_const[eb_config.name] = const  # constraints are keyed by table name ("table_0")
    topology = Topology(
        local_world_size=get_local_size(),
        world_size=dist.get_world_size(),
        compute_device=device.type,
        hbm_cap=hbm_cap,
        ddr_cap=ddr_cap,  # for HKV, setting the DDR capacity matters if embedding vectors are placed in host memory
        intra_host_bw=intra_host_bw,
        inter_host_bw=inter_host_bw,
    )

    # Same usage as TorchREC's EmbeddingEnumerator
    enumerator = DynamicEmbeddingEnumerator(
        topology=topology,
        constraints=dict_const,
    )

    # Almost the same usage as TorchREC's EmbeddingShardingPlanner, except that eb_configs
    # must be passed in so the planner can plan the HKV object on every GPU.
    return DynamicEmbeddingShardingPlanner(
        eb_configs=eb_configs,
        topology=topology,
        constraints=dict_const,
        batch_size=batch_size,
        enumerator=enumerator,
        storage_reservation=HeuristicalStorageReservation(percentage=0.05),
        debug=True,
    )

planner = get_planner(device, eb_configs, batch_size)
```
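The HBM budget computed inside `get_planner` can be checked by hand. Here the table holds 1000 FP32 vectors of dimension 128; rounding the capacity up to a power of two gives 1024 slots, so the values need 4 * 128 * 1024 = 524,288 bytes (512 KiB). The standalone sketch below repeats that arithmetic with plain integers. `DTYPE_BYTES`, `next_power_of_2`, and `hbm_bytes_for_values` are illustrative names, not DynamicEmb APIs, and like the notebook's calculation they count only the embedding values.

```python
# Bytes per element for the data types listed in get_planner.
DTYPE_BYTES = {"FP32": 4, "FP16": 2, "BF16": 2}


def next_power_of_2(n: int) -> int:
    """Smallest power of two that is >= n, using exact integer arithmetic."""
    if n < 1:
        raise ValueError("n must be at least 1")
    return 1 << (n - 1).bit_length()


def hbm_bytes_for_values(num_embeddings: int, dim: int, dtype: str = "FP32") -> int:
    """Bytes needed to keep every embedding vector of one table in HBM,
    with the capacity rounded up to a power of two as in get_planner."""
    return DTYPE_BYTES[dtype] * dim * next_power_of_2(num_embeddings)
```

Using `int.bit_length()` instead of `math.ceil(math.log2(n))` avoids floating-point rounding for very large counts: for `n = 2**53 + 1`, `math.log2` returns exactly 53.0, so the float version would pick a capacity smaller than `n`.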

### Get a plan and use TorchREC's DistributedModelParallel
With the DynamicEmb planner in place, the next step is to configure the optimizer and the sharder. We then wrap the model with TorchREC's `DistributedModelParallel`.

The process is the same as in TorchREC, with one exception: wherever you would use TorchREC's `EmbeddingCollectionSharder`, use `DynamicEmbeddingCollectionSharder` instead. `DynamicEmbeddingCollectionSharder` overrides the index-deduplication step of TorchREC's `EmbeddingCollectionSharder` to accommodate HKV.
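With `use_index_dedup=True`, duplicate ids within a batch are, roughly speaking, removed before the lookup and the results are expanded back to the original order afterwards (in TorchREC this also reduces the ids exchanged between ranks). In backward, the gradients of repeated ids are summed so that each table row receives one combined update. The plain-Python sketch below illustrates only that idea; it is not the DynamicEmb or TorchREC implementation, and the functions and the dict-based `table` are hypothetical stand-ins for the HKV hash table.

```python
def dedup_indices(indices):
    """Return (unique_ids, inverse) such that unique_ids[inverse[i]] == indices[i]."""
    position = {}
    unique_ids, inverse = [], []
    for idx in indices:
        if idx not in position:
            position[idx] = len(unique_ids)
            unique_ids.append(idx)
        inverse.append(position[idx])
    return unique_ids, inverse


def lookup_with_dedup(table, indices):
    """Forward: fetch each distinct row once, then expand back to the input order."""
    unique_ids, inverse = dedup_indices(indices)
    unique_rows = [table[i] for i in unique_ids]  # one table access per distinct id
    return [unique_rows[j] for j in inverse], unique_ids, inverse


def reduce_grads(output_grads, inverse, num_unique):
    """Backward: sum the gradients of repeated lookups into one gradient per distinct id."""
    dim = len(output_grads[0])
    grads = [[0.0] * dim for _ in range(num_unique)]
    for grad, j in zip(output_grads, inverse):
        for k in range(dim):
            grads[j][k] += grad[k]
    return grads
```

The training loop later in this notebook uses `loss = cat_tensor.sum()`, so every output gradient is all ones and each distinct id accumulates a gradient equal to the number of times it occurs in the batch.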


```python
from fbgemm_gpu.split_embedding_configs import EmbOptimType
from fbgemm_gpu.split_embedding_configs import SparseType

from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType
from torchrec.distributed.model_parallel import DefaultDataParallelWrapper, DistributedModelParallel

from dynamicemb.shard import DynamicEmbeddingCollectionSharder

# Set the optimizer arguments
learning_rate = 0.1
beta1 = 0.9
beta2 = 0.999
weight_decay = 0
eps = 0.001

# Put the arguments into the optimizer kwargs, the same way as in TorchREC
optimizer_kwargs = {
    "optimizer": EmbOptimType.ADAM,
    "learning_rate": learning_rate,
    "beta1": beta1,
    "beta2": beta2,
    "weight_decay": weight_decay,
    "eps": eps,
}

fused_params = {}
fused_params["output_dtype"] = SparseType.FP32
fused_params.update(optimizer_kwargs)

qcomm_codecs_registry = (
    get_qcomm_codecs_registry(
        qcomms_config=QCommsConfig(
            # pyre-ignore
            forward_precision=CommType.FP32,
            # pyre-ignore
            backward_precision=CommType.FP32,
        )
    )
    if backend == "nccl"
    else None
)

# Create the sharder as in TorchREC, but use the DynamicEmb class:
# it overrides the index_dedup process to fit HKV.
sharder = DynamicEmbeddingCollectionSharder(
    qcomm_codecs_registry=qcomm_codecs_registry,
    fused_params=fused_params,
    use_index_dedup=True,
)

# Same usage as TorchREC
plan = planner.collective_plan(ebc, [sharder], dist.GroupMember.WORLD)

data_parallel_wrapper = DefaultDataParallelWrapper(
    allreduce_comm_precision="fp16"
)

# Same usage as TorchREC
model = DistributedModelParallel(
    module=ebc,
    device=device,
    # pyre-ignore
    sharders=[sharder],
    plan=plan,
    data_parallel_wrapper=data_parallel_wrapper,
)

print(model)
```

```
DistributedModelParallel(
  (_dmp_wrapped_module): ShardedDynamicEmbeddingCollection(
    (lookups):
     GroupedEmbeddingsLookup(
        (_emb_modules): ModuleList(
          (0): BatchedDynamicEmbedding(
            (_emb_module): BatchedDynamicEmbeddingTables()
          )
        )
      )
     (_output_dists):
     RwSequenceEmbeddingDist(
        (_dist): SequenceEmbeddingsAllToAll()
      )
    (embeddings): ModuleDict(
      (table_0): Module()
    )
  )
)
```
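The `optimizer_kwargs` above select `EmbOptimType.ADAM` as a fused optimizer, so, as with TorchREC's fused embedding optimizers, rows are expected to be updated inside `loss.backward()` rather than by a separate `optimizer.step()`. To show what `learning_rate`, `beta1`, `beta2`, `eps`, and `weight_decay` control, the sketch below applies one step of the textbook bias-corrected Adam update to a single embedding row. It assumes that formulation; the exact arithmetic inside the fused FBGEMM/DynamicEmb kernel (for example, how `weight_decay` is applied) may differ, and `adam_row_update` is a hypothetical helper.

```python
import math


def adam_row_update(weights, grads, m, v, step, lr=0.1, beta1=0.9, beta2=0.999,
                    eps=0.001, weight_decay=0.0):
    """Apply one bias-corrected Adam step to a single embedding row, in place.

    `step` is 1-based. weight_decay is folded into the gradient (L2 style),
    which is an assumption about the fused kernel's behaviour.
    """
    for k in range(len(weights)):
        g = grads[k] + weight_decay * weights[k]
        m[k] = beta1 * m[k] + (1 - beta1) * g          # first-moment estimate
        v[k] = beta2 * v[k] + (1 - beta2) * g * g      # second-moment estimate
        m_hat = m[k] / (1 - beta1 ** step)             # bias corrections
        v_hat = v[k] / (1 - beta2 ** step)
        weights[k] -= lr * m_hat / (math.sqrt(v_hat) + eps)
```

On the first step, `m_hat / sqrt(v_hat)` is just the sign of the gradient, so each element moves by about `lr / (1 + eps / |g|)` regardless of the gradient's scale; the relatively large `eps = 0.001` used here slightly damps steps for small gradients.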


## Generate data and run the forward and backward passes

With the `DistributedModelParallel` model created, we can now train the embedding lookup. The code below first generates random training data and then runs the forward and backward passes.
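`generate_sparse_feature` packs the sampled ids into a `KeyedJaggedTensor`, which stores one flat `values` buffer plus a `lengths` entry per sample (and, for several keys, the features one after another). The plain-Python sketch below shows that layout for a single feature, including the offsets derived from the lengths. `build_jagged` and `sample_ids` are hypothetical helpers for illustration only.

```python
def build_jagged(per_sample_ids):
    """Flatten per-sample id lists into (values, lengths, offsets) for one feature."""
    values, lengths, offsets = [], [], [0]
    for ids in per_sample_ids:
        values.extend(ids)
        lengths.append(len(ids))
        offsets.append(offsets[-1] + len(ids))  # offsets are the running sum of lengths
    return values, lengths, offsets


def sample_ids(values, offsets, i):
    """Recover the ids of sample i from the flat buffer."""
    return values[offsets[i]:offsets[i + 1]]
```

In this notebook every sample has exactly 10 ids and the local batch size is 16, so each batch carries 160 values, and the lookup output for `cate_0` has 160 rows of 128 dimensions.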

```python
import numpy as np

num_iterations = 10

# Generate random indices to look up, packed into a KeyedJaggedTensor.
# Each feature has its own table, so indices are drawn from [0, num_embeddings) per feature.
def generate_sparse_feature(feature_num, num_embeddings_list, max_sequence_size, local_batch_size=50):
    indices = []
    lengths = []

    for f in range(feature_num):
        # Sampled with replacement, so the same index can appear several times in a batch
        sampled_indices = np.random.choice(
            num_embeddings_list[f], size=(local_batch_size, max_sequence_size[f]), replace=True
        )
        indices.extend(sampled_indices.flatten())
        # Every sample gets exactly max_sequence_size[f] indices
        lengths.extend([max_sequence_size[f]] * local_batch_size)

    return torchrec.KeyedJaggedTensor(
        keys=[f"cate_{feature_idx}" for feature_idx in range(feature_num)],
        values=torch.tensor(indices, dtype=torch.int64).cuda(),
        lengths=torch.tensor(lengths, dtype=torch.int64).cuda(),
    )

sparse_features = []
for i in range(num_iterations):
    sparse_features.append(
        generate_sparse_feature(
            feature_num=1,
            num_embeddings_list=[total_num_embedding],
            max_sequence_size=[10],
            local_batch_size=batch_size // world_size,
        )
    )

for i in range(num_iterations):
    sparse_feature = sparse_features[i]
    ret = model(sparse_feature)

    feature_names = []
    tensors = []
    for k, v in ret.items():
        feature_names.append(k)
        tensors.append(v.values())

    cat_tensor = torch.cat(tensors, dim=1)
    print(f"iter : {i} , cat_tensor = {cat_tensor}")
    loss = cat_tensor.sum()
    # The fused Adam optimizer configured through fused_params updates the
    # embedding rows during backward, so there is no separate optimizer.step().
    loss.backward()
```

```
iter : 0 , cat_tensor = tensor([[-0.3054, -0.3133, -0.2628,  ...,  0.3496, -0.1596,  0.5477],
        [-0.9145,  0.8816,  0.0056,  ...,  0.1427,  0.3048, -1.7920],
        [-0.3354,  1.6754, -0.5813,  ..., -1.3018, -0.8106, -0.6762],
        ...,
        [ 2.2891, -1.5476,  0.6556,  ..., -1.6616, -0.3322,  0.5982],
        [-1.3867,  1.6895,  1.3594,  ...,  1.2477, -1.1120, -0.4656],
        [ 1.6933,  0.4228, -0.5796,  ..., -0.6652,  0.1855, -0.2718]],
       device='cuda:0', grad_fn=<CatBackward0>)
iter : 1 , cat_tensor = tensor([[-1.3479,  1.5079, -0.7182,  ..., -0.4937, -0.6770,  0.8667],
        [ 0.6271, -0.2502, -0.6837,  ...,  0.1481,  0.6616, -0.1672],
        [-0.3730, -0.0913,  1.1698,  ..., -0.0586, -0.6082,  1.6253],
        ...,
        [ 0.9045,  0.3579,  1.3050,  ..., -0.2884,  0.2628,  1.4113],
        [-0.6230,  0.9093,  1.6367,  ...,  1.5787, -0.1485, -0.4934],
        [-2.1421,  0.1621, -0.7686,  ..., -1.1568, -0.3575, -1.2740]],
       device='cuda:0', grad_fn=<Cat
```

## More resources
For more information, see DynamicEmb's `README.md` and `DynamicEmb_APIs.md`.

To compare raw TorchREC with DynamicEmb, see the `README.md` in the benchmark folder and run the benchmark on your own node.