In [1]:
import os
import torch
import torchrec
import torch.distributed as dist

In [2]:
# Rendezvous settings for a single-process "cluster": one rank, world size 1,
# with the coordinator on this machine.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"

# The default process group can only be created once per process, so guard the
# call: this keeps the cell safe to re-run without restarting the kernel.
if not dist.is_initialized():
    dist.init_process_group(backend="nccl")

In [3]:
# One sum-pooled table per sparse feature; both tables are 4096 rows x 64 dims.
product_table_config = torchrec.EmbeddingBagConfig(
    name="product_table",
    embedding_dim=64,
    num_embeddings=4096,
    feature_names=["product"],
    pooling=torchrec.PoolingType.SUM,
)
user_table_config = torchrec.EmbeddingBagConfig(
    name="user_table",
    embedding_dim=64,
    num_embeddings=4096,
    feature_names=["user"],
    pooling=torchrec.PoolingType.SUM,
)

# The collection is created on the "meta" device.
# NOTE(review): presumably so no real weight storage is allocated before
# sharding -- confirm against the TorchRec docs.
ebc = torchrec.EmbeddingBagCollection(
    device="meta",
    tables=[product_table_config, user_table_config],
)

In [4]:
# Wrap the embedding collection in DistributedModelParallel, targeting CUDA.
cuda_device = torch.device("cuda")
model = torchrec.distributed.DistributedModelParallel(ebc, device=cuda_device)

# Print the wrapped module tree first, then the model's `plan` attribute.
for report in (model, model.plan):
    print(report)



DistributedModelParallel(
  (_dmp_wrapped_module): ShardedEmbeddingBagCollection(
    (lookups): 
     GroupedPooledEmbeddingsLookup(
        (_emb_modules): ModuleList(
          (0): BatchedFusedEmbeddingBag(
            (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()
          )
        )
      )
     (_output_dists): 
     TwPooledEmbeddingDist()
    (embedding_bags): ModuleDict(
      (product_table): Module()
      (user_table): Module()
    )
  )
)
module: 

    param     | sharding type | compute kernel | ranks
------------- | ------------- | -------------- | -----
product_table | table_wise    | fused          | [0]  
user_table    | table_wise    | fused          | [0]  

    param     | shard offsets | shard sizes |   placement  
------------- | ------------- | ----------- | -------------
product_table | [0, 0]        | [4096, 64]  | rank:0/cuda:0
user_table    | [0, 0]        | [4096, 64]  | rank:0/cuda:0


In [5]:
# Plain PyTorch counterpart: a single bag-style table, 4096 rows of width 64.
product_eb = torch.nn.EmbeddingBag(num_embeddings=4096, embedding_dim=64)

# Flat ids plus the start offset of each bag: the three bags are
# [101, 202], [] (empty, pooled to zeros in the output) and [303].
bag_ids = torch.tensor([101, 202, 303])
bag_starts = torch.tensor([0, 2, 2])
product_eb(input=bag_ids, offsets=bag_starts)

tensor([[ 0.8913, -0.0430,  0.1290, -0.7120,  1.3198, -1.0814,  0.3503, -0.4333,
          0.5504,  1.0207,  0.0537, -0.4462, -1.1212, -0.1400,  0.2721, -0.5849,
          0.7315,  0.3629, -0.3564,  0.7947, -1.1947,  0.0654,  0.5599,  0.9508,
          0.5029,  0.0540, -0.1085,  0.6997,  0.4432,  0.0243,  0.8940,  1.3384,
          0.7675,  0.4572, -0.1275,  0.1669, -0.5792, -0.8107, -0.9844, -0.5064,
          0.8403,  0.0323,  0.0767, -0.9152,  0.6379,  1.0795,  0.0330,  0.2821,
          0.3235,  0.4667, -0.7850,  0.2016, -1.0983, -0.5974,  0.3752,  0.1714,
          0.2108, -0.1451, -0.5072, -0.0798,  0.1194,  0.4264, -1.0116, -0.9945],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, 

In [6]:
# Mini-batch of three examples over two features. Lengths are listed
# feature by feature (product: 2, 0, 1 then user: 1, 1, 1), which is how the
# printed result below groups the values.
feature_values = torch.tensor([101, 202, 303, 404, 505, 606]).cuda()
feature_lengths = torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda()
mb = torchrec.KeyedJaggedTensor(
    keys=["product", "user"],
    values=feature_values,
    lengths=feature_lengths,
)

# Move to CPU only for display; `mb` itself stays on the GPU.
print(mb.to(torch.device("cpu")))

KeyedJaggedTensor({
    "product": [[101, 202], [], [303]],
    "user": [[404], [505], [606]]
})



In [16]:
# Run the mini-batch through the sharded model, then print the result after
# moving it to the CPU.
pooled_embeddings = model(mb)
cpu_device = torch.device("cpu")
print(pooled_embeddings.to(cpu_device))

KeyedTensor({
    "product": [[0.0018351711332798004, -0.01855020970106125, 0.002718959003686905, -0.006389661692082882, -0.019211817532777786, 0.01953590288758278, -0.0005252901464700699, -0.0034213941544294357, -0.001344003714621067, 0.020707810297608376, -0.0001426665112376213, -0.0017040008679032326, 0.026214448735117912, 0.0006860718131065369, -0.01067335344851017, 0.023592425510287285, -0.007505293004214764, -0.015248609706759453, 0.01786927320063114, -0.00908808596432209, 0.01535269059240818, 0.02175009995698929, -0.028214337304234505, -0.011856856755912304, -0.007146292366087437, 0.013943290337920189, -0.006174187175929546, -0.0020071491599082947, -0.01591501757502556, 0.016104411333799362, 0.007327594794332981, 0.004775887355208397, -0.013550353236496449, -0.008402235805988312, -0.0003183335065841675, 0.015592729672789574, -0.01238129660487175, 0.001090196892619133, 0.01412959210574627, 0.007699424400925636, 0.00788839254528284, -0.006513725966215134, 0.0210098996758461, 0.002