Reference: https://docs.pytorch.org/tutorials/intermediate/torchrec_intro_tutorial.html

## Embeddings Recap

In [1]:
import torch
from torchrec import JaggedTensor

In [2]:
# Table geometry: number of rows (one per ID) and the width of each embedding vector
num_embeddings = 10
embedding_dim = 4

# Create the embedding table with random values; row i is the embedding of ID i
weights = torch.rand((num_embeddings, embedding_dim))
print("Weights: ", weights)

Weights:  tensor([[0.4448, 0.0612, 0.2466, 0.6927],
        [0.0136, 0.7638, 0.5630, 0.5855],
        [0.1150, 0.9323, 0.9310, 0.5350],
        [0.2111, 0.3665, 0.4667, 0.3583],
        [0.1836, 0.8943, 0.9763, 0.5469],
        [0.1469, 0.5513, 0.7553, 0.8704],
        [0.0461, 0.1049, 0.5266, 0.1406],
        [0.9534, 0.1321, 0.7688, 0.7059],
        [0.4702, 0.0154, 0.3434, 0.2470],
        [0.4574, 0.4110, 0.8286, 0.9001]])


In [5]:
# Pass in pre-generated weights for demonstration purposes, so that both modules
# look up rows from the same table and their results can be compared.
# (In real models the weights are usually randomly initialized by the module.)
# `from_pretrained` is the public API for supplying an existing weight tensor;
# `freeze=False` keeps the table trainable (requires_grad=True).
embedding_collection = torch.nn.Embedding.from_pretrained(weights, freeze=False)

# nn.EmbeddingBag additionally pools the looked-up rows of each bag;
# its default pooling mode is "mean".
embedding_bag_collection = torch.nn.EmbeddingBag.from_pretrained(
    weights, freeze=False
)

print("Embedding Collection Table: ", embedding_collection.weight)
print("Embedding Bag Collection Table: ", embedding_bag_collection.weight)

Embedding Collection Table:  Parameter containing:
tensor([[0.4448, 0.0612, 0.2466, 0.6927],
        [0.0136, 0.7638, 0.5630, 0.5855],
        [0.1150, 0.9323, 0.9310, 0.5350],
        [0.2111, 0.3665, 0.4667, 0.3583],
        [0.1836, 0.8943, 0.9763, 0.5469],
        [0.1469, 0.5513, 0.7553, 0.8704],
        [0.0461, 0.1049, 0.5266, 0.1406],
        [0.9534, 0.1321, 0.7688, 0.7059],
        [0.4702, 0.0154, 0.3434, 0.2470],
        [0.4574, 0.4110, 0.8286, 0.9001]], requires_grad=True)
Embedding Bag Collection Table:  Parameter containing:
tensor([[0.4448, 0.0612, 0.2466, 0.6927],
        [0.0136, 0.7638, 0.5630, 0.5855],
        [0.1150, 0.9323, 0.9310, 0.5350],
        [0.2111, 0.3665, 0.4667, 0.3583],
        [0.1836, 0.8943, 0.9763, 0.5469],
        [0.1469, 0.5513, 0.7553, 0.8704],
        [0.0461, 0.1049, 0.5266, 0.1406],
        [0.9534, 0.1321, 0.7688, 0.7059],
        [0.4702, 0.0154, 0.3434, 0.2470],
        [0.4574, 0.4110, 0.8286, 0.9001]], requires_grad=True)


In [8]:
# One batch entry holding two row IDs -> input of shape (1, 2)
ids = torch.tensor([[1, 3]])
print("Input row IDS: ", ids)

# nn.Embedding returns one table row per ID, without any pooling
embeddings = embedding_collection(ids)

# Show the looked-up rows (header line first, then the tensor) and their shape
print("Embedding Collection Results: ", embeddings, sep="\n")
print("Shape: ", embeddings.shape)

Input row IDS:  tensor([[1, 3]])
Embedding Collection Results: 
tensor([[[0.0136, 0.7638, 0.5630, 0.5855],
         [0.2111, 0.3665, 0.4667, 0.3583]]], grad_fn=<EmbeddingBackward0>)
Shape:  torch.Size([1, 2, 4])


In [9]:
# nn.EmbeddingBag looks up the same rows and then pools them per bag: with the
# default mode it takes the mean over the IDs of each bag (dim=1 of the
# un-pooled lookup result), leaving one vector per batch entry
pooled_embeddings = embedding_bag_collection(ids)
print("Embedding Bag Collection Results: ", pooled_embeddings, sep="\n")
print("Shape: ", pooled_embeddings.shape)

Embedding Bag Collection Results: 
tensor([[0.1124, 0.5651, 0.5148, 0.4719]], grad_fn=<EmbeddingBagBackward0>)
Shape:  torch.Size([1, 4])


In [10]:
# Equivalent manual computation: look up the rows without pooling,
# then average over the IDs axis (dim=1)
unpooled = embedding_collection(ids)
print("Mean: ", unpooled.mean(dim=1))

Mean:  tensor([[0.1124, 0.5651, 0.5148, 0.4719]], grad_fn=<MeanBackward1>)


## TorchRec features

In [11]:
import torchrec

`EmbeddingBagCollection`: represents a group of embedding bags

Example: Create an `EmbeddingBagCollection` (EBC) with two embedding bags, one representing products and one representing users.

In [12]:
# Both tables share the same geometry (4096 rows x 64 dims) and SUM pooling;
# only the table name and the feature routed to it differ.
ebc_tables = [
    torchrec.EmbeddingBagConfig(
        name=table_name,
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=[feature_name],
        pooling=torchrec.PoolingType.SUM,
    )
    for table_name, feature_name in (
        ("product_table", "product"),
        ("user_table", "user"),
    )
]

# Group the tables into a single collection that lives on the CPU
ebc = torchrec.EmbeddingBagCollection(device="cpu", tables=ebc_tables)
print(ebc.embedding_bags)

ModuleDict(
  (product_table): EmbeddingBag(4096, 64, mode='sum')
  (user_table): EmbeddingBag(4096, 64, mode='sum')
)


Inspect the `forward` method of the EBC:

In [13]:
import inspect
# Fetch the source code of the EBC's forward pass and display it
forward_source = inspect.getsource(ebc.forward)
print(forward_source)

    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
        """
        Args:
            features (KeyedJaggedTensor): KJT of form [F X B X L].

        Returns:
            KeyedTensor
        """

        pooled_embeddings: List[torch.Tensor] = []

        feature_dict = features.to_dict()
        for i, embedding_bag in enumerate(self.embedding_bags.values()):
            for feature_name in self._feature_names[i]:
                f = feature_dict[feature_name]
                res = embedding_bag(
                    input=f.values(),
                    offsets=f.offsets(),
                    per_sample_weights=f.weights() if self._is_weighted else None,
                ).float()
                pooled_embeddings.append(res)
        data = torch.cat(pooled_embeddings, dim=1)
        return KeyedTensor(
            keys=self._embedding_names,
            values=data,
            length_per_key=self._lengths_per_embedding,
        )



### Input / Output Data Types

In [15]:
# A batch of 2 examples for a single ID-list feature, stored in flat form.
# lengths[i] is the number of IDs in example i: 1 ID in example 1, 2 IDs in example 2
id_list_feature_lengths = torch.tensor([1, 2])
# All IDs concatenated in example order: ID 5 belongs to example 1, IDs 7 and 1 to example 2
id_list_feature_values = torch.tensor([5, 7, 1])
# The running sum of the lengths gives the end position of each example
# inside the flat values tensor, which makes slicing by example easy
id_list_feature_offsets = id_list_feature_lengths.cumsum(dim=0)

In [16]:
# Slice the flat values tensor back into per-example ID lists:
# each offset marks where one example ends and the next one begins
first_end = id_list_feature_offsets[0]
second_end = id_list_feature_offsets[1]
print("Offsets: ", id_list_feature_offsets)
print("First Batch: ", id_list_feature_values[:first_end])
print("Second Batch: ", id_list_feature_values[first_end:second_end])

Offsets:  tensor([1, 3])
First Batch:  tensor([5])
Second Batch:  tensor([7, 1])


`torchrec` abstraction: `JaggedTensor`

In [20]:
from torchrec import JaggedTensor

# JaggedTensor bundles the flat values with their per-example lengths
jt = JaggedTensor(lengths=id_list_feature_lengths, values=id_list_feature_values)

# offsets() exposes the example boundaries; to_dense() yields one tensor of IDs per example
jt_offsets = jt.offsets()
jt_examples = jt.to_dense()
print("Offsets: ", jt_offsets)
print("List of values: ", jt_examples)
print(jt)

Offsets:  tensor([0, 1, 3])
List of values:  [tensor([5]), tensor([7, 1])]
JaggedTensor({
    [[5], [7, 1]]
})



In [33]:
from torchrec import KeyedJaggedTensor
# A ``JaggedTensor`` holds the IDs of a single feature, but an ``EmbeddingBagCollection``
# consumes several features at once. That's where ``KeyedJaggedTensor`` comes in:
# it holds one ``JaggedTensor`` per feature key.
# The EBC from before has the two features "product" and "user", so build a ``JaggedTensor`` for each.
product_jt = JaggedTensor(
    lengths=torch.tensor([3, 4]),
    values=torch.tensor([1, 2, 3, 1, 5, 7, 9]),
)
user_jt = JaggedTensor(
    lengths=torch.tensor([2, 2]),
    values=torch.tensor([2, 3, 4, 1]),
)
features_by_key = {"product": product_jt, "user": user_jt}
kjt = KeyedJaggedTensor.from_jt_dict(features_by_key)

# The batch size is 2:
#   "product": example 1 has 3 IDs, example 2 has 4 IDs
#   "user":    example 1 has 2 IDs, example 2 has 2 IDs
for label, getter in (
    ("Keys: ", kjt.keys),
    ("Lengths: ", kjt.lengths),
    ("Values: ", kjt.values),
    ("to_dict: ", kjt.to_dict),
):
    print(label, getter())
print(kjt)

Keys:  ['product', 'user']
Lengths:  tensor([3, 4, 2, 2])
Values:  tensor([1, 2, 3, 1, 5, 7, 9, 2, 3, 4, 1])
to_dict:  {'product': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7d80cd55b750>, 'user': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7d8154291fd0>}
KeyedJaggedTensor({
    "product": [[1, 2, 3], [1, 5, 7, 9]],
    "user": [[2, 3], [4, 1]]
})



In [34]:
# Forward pass through the `EmbeddingBagCollection`:
# the KeyedJaggedTensor goes in as `features`, a KeyedTensor comes out
result = ebc(features=kjt)
result

<torchrec.sparse.jagged_tensor.KeyedTensor at 0x7d8156e76850>

In [35]:
# Feature keys of the pooled output
print(result.keys())

# Each bag of IDs was pooled into a single vector (SUM pooling, as configured on the
# tables) and the per-feature vectors are concatenated along dim=1: 2 features x 64 = 128
pooled_values = result.values()
print(pooled_values.shape)

# to_dict() gives the pooled embeddings per feature key
result_dict = result.to_dict()
for key in result_dict:
    embedding = result_dict[key]
    print(key, embedding.shape)

['product', 'user']
torch.Size([2, 128])
product torch.Size([2, 64])
user torch.Size([2, 64])


## Distributed Training and Sharding