# Zero-collision Hash Tutorial
This example notebook goes through the following topics:
- Why do we need zero-collision hash?
- How to use the zero-collision module in TorchRec?

## Prerequisites
Before diving into the details, let's import all the necessary packages. This requires [the latest `torchrec` library to be installed](https://docs.pytorch.org/torchrec/setup-torchrec.html#installation).

In [None]:
import torch
from torch import nn
from torchrec import (
    EmbeddingCollection,
    EmbeddingConfig,
    JaggedTensor,
    KeyedJaggedTensor,
    KeyedTensor,
)

from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection

from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)



## Hash and Zero Collision Hash
In this section, we motivate zero-collision hashing by answering two questions:
- Why do we need to hash incoming features?
- Why do we need a zero-collision hash?

Let's start with the first question: why do we need to hash sparse feature inputs in a recommendation model?  
We first create an embedding table with 1000 embeddings.

In [None]:
# define the number of embeddings
num_embeddings = 1000
table_config = EmbeddingConfig(
    name="t1",
    embedding_dim=16,
    num_embeddings=num_embeddings,
    feature_names=["f1"],
)
ec = EmbeddingCollection(tables=[table_config])

Usually, we treat each input sparse feature ID as an index into the embedding table and fetch the embedding stored at that slot. However, while the size of an embedding table is fixed when the model is instantiated, the number of sparse feature values, such as video tags, can keep growing. After a while, the ID of a sparse feature can be larger than the size of our embedding table.

In [None]:
feature_id = num_embeddings + 1
input_kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([feature_id]),
    lengths=torch.tensor([1]),
)

At that point, querying the embedding table raises an `IndexError` with the message `index out of range`.

In [None]:
try:
    feature_embedding = ec(input_kjt)
except IndexError as e:
    print(f"Query the embedding table of size {num_embeddings} with sparse feature ID {input_kjt['f1'].values()}")
    print(f"This query throws an IndexError: {e}")

Query the embedding table of size 1000 with sparse feature ID tensor([1001])
This query throws an IndexError: index out of range in self


To prevent this error, we hash the sparse feature ID to a value within the range of the embedding table size, and use the hashed value as the feature ID to query the embedding table.

For demonstration purposes, we use Python's built-in `hash` function (which, in CPython, returns small non-negative integers unchanged) and remap the result to the range `[0, num_embeddings)` by taking it modulo `num_embeddings`.
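
As a quick sanity check of the statement that `hash` does not change an integer's value, the sketch below hashes a few made-up IDs and takes the modulo. The names `table_size`, `sample_ids`, and `remapped` are illustrative and not part of the notebook; the behavior of `hash` on small non-negative integers is assumed to be that of CPython.

```python
# Assumption: CPython, where hash() of a small non-negative int is the int itself.
table_size = 1000  # stands in for num_embeddings
sample_ids = [1, 999, 1000, 1001, 123456]  # made-up feature IDs

remapped = {}
for sample_id in sample_ids:
    assert hash(sample_id) == sample_id  # hashing leaves the value unchanged
    remapped[sample_id] = hash(sample_id) % table_size

print(remapped)
# {1: 1, 999: 999, 1000: 0, 1001: 1, 123456: 456}
```

Every remapped value falls in `[0, table_size)`, so it is always a valid index into a table of that size.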

In [None]:
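def remap(input_jt_value: int, num_embeddings: int) -> int:
    # hash() of a small non-negative int is the int itself in CPython
    input_hash = hash(input_jt_value)
    # % with a positive modulus always yields a value in [0, num_embeddings)
    return input_hash % num_embeddings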

Now we can query the embedding table with the remapped ID without an error.

In [None]:
remapped_id = remap(feature_id, num_embeddings)
remapped_kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([remapped_id]),
    lengths=torch.tensor([1]),
)
feature_embedding = ec(remapped_kjt)
print(f"Query the embedding table of size {num_embeddings} with remapped sparse feature ID {remapped_id} from original ID {feature_id}")
print(f"This query does not throw an IndexError, and returns the embedding of the remapped ID: {feature_embedding}")

Query the embedding table of size 1000 with remapped sparse feature ID 1 from original ID 1001
This query does not throw an IndexError, and returns the embedding of the remapped ID: {'f1': <torchrec.sparse.jagged_tensor.JaggedTensor object at 0x7fe444183bf0>}


Having answered the first question, __why do we need to hash incoming features?__, we can now turn to the second: __why do we need a zero-collision hash?__

Because we are mapping a larger range of values into a smaller one, some values will be remapped to the same index. For example, our `remap` function returns the same remapped ID for the features `num_embeddings + 1` and `1`.

In [None]:
feature_id_1 = 1
feature_id_2 = num_embeddings + 1
remapped_feature_id_1 = remap(feature_id_1, num_embeddings)
remapped_feature_id_2 = remap(feature_id_2, num_embeddings)
print(f"feature ID {feature_id_1} is remapped to ID {remapped_feature_id_1}")
print(f"feature ID {feature_id_2} is remapped to ID {remapped_feature_id_2}")
print(f"Check if remapped feature ID {remapped_feature_id_1} and {remapped_feature_id_2} are the same: {remapped_feature_id_1 == remapped_feature_id_2}")

feature ID 1 is remapped to ID 1
feature ID 1001 is remapped to ID 1
Check if remapped feature ID 1 and 1 are the same: True


In this case, two totally different features can share the same embedding. The situation when two feature IDs share the same remapped ID is called a **hash collision**.

In [None]:
input_kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.tensor([remapped_feature_id_1, remapped_feature_id_2]),
    lengths=torch.tensor([1, 1]),
)
feature_embeddings = ec(input_kjt)
feature_id_1_embedding = feature_embeddings["f1"].values()[0]
feature_id_2_embedding = feature_embeddings["f1"].values()[1]
print(f"Embedding of feature ID {remapped_feature_id_1} is {feature_id_1_embedding}")
print(f"Embedding of feature ID {remapped_feature_id_2} is {feature_id_2_embedding}")
print(f"Check if the embeddings of feature ID {remapped_feature_id_1} and {remapped_feature_id_2} are the same: {torch.equal(feature_id_1_embedding, feature_id_2_embedding)}")

Embedding of feature ID 1 is tensor([ 0.0232,  0.0075,  0.0281, -0.0195, -0.0301,  0.0033,  0.0303,  0.0294,
         0.0301, -0.0287, -0.0130, -0.0194,  0.0263,  0.0287,  0.0261, -0.0080],
       grad_fn=<SelectBackward0>)
Embedding of feature ID 1 is tensor([ 0.0232,  0.0075,  0.0281, -0.0195, -0.0301,  0.0033,  0.0303,  0.0294,
         0.0301, -0.0287, -0.0130, -0.0194,  0.0263,  0.0287,  0.0261, -0.0080],
       grad_fn=<SelectBackward0>)
Check if the embeddings of feature ID 1 and 1 are the same: True
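
To make the collision pattern visible for more than two IDs, the following sketch groups a handful of raw IDs by the slot they land in under the modulo remapping. The ID list is made up for illustration, and `table_size`, `raw_ids`, `slots`, and `collisions` are hypothetical names.

```python
from collections import defaultdict

table_size = 1000  # same size as the embedding table above
# Made-up raw feature IDs; some exceed the table size.
raw_ids = [1, 7, 1001, 2001, 5007, 42]

# Group raw IDs by the slot they are remapped to.
slots = defaultdict(list)
for raw_id in raw_ids:
    slots[raw_id % table_size].append(raw_id)

# Keep only the slots shared by more than one raw ID.
collisions = {slot: ids for slot, ids in slots.items() if len(ids) > 1}
print(collisions)
# {1: [1, 1001, 2001], 7: [7, 5007]}
```

Five of the six IDs end up sharing an embedding with at least one other ID, even though only six of the 1000 slots are needed.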



Making two different (and potentially completely irrelevant) features share the same embedding leads to inaccurate recommendations.
Luckily, for many sparse features, although their ID range can be larger than the embedding table size, the IDs that actually occur are sparsely distributed over that range.
In other cases, the embedding table may only receive frequent queries for a subset of the features.
So we can design __managed collision hash__ modules that prevent hash collisions from happening.
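
One simple way to picture collision-free remapping is a lookup table that hands each previously unseen ID the next free slot. The class below is a toy sketch of that idea only: `ToyZeroCollisionRemapper` is a hypothetical name, and this is not how TorchRec's modules are implemented.

```python
from typing import Dict


class ToyZeroCollisionRemapper:
    """Toy illustration only; not TorchRec's algorithm."""

    def __init__(self, table_size: int) -> None:
        self.table_size = table_size
        self.id_to_slot: Dict[int, int] = {}

    def remap(self, raw_id: int) -> int:
        # An ID that was seen before keeps its slot.
        if raw_id in self.id_to_slot:
            return self.id_to_slot[raw_id]
        # Once the table is full, fall back to modulo remapping,
        # which means collisions become possible again.
        if len(self.id_to_slot) >= self.table_size:
            return raw_id % self.table_size
        # Otherwise give the new ID the next unused slot.
        slot = len(self.id_to_slot)
        self.id_to_slot[raw_id] = slot
        return slot


toy_remapper = ToyZeroCollisionRemapper(table_size=1000)
print(toy_remapper.remap(1))     # 0
print(toy_remapper.remap(1001))  # 1
print(toy_remapper.remap(1))     # 0 (same ID, same slot)
```

As long as fewer distinct IDs than slots have been seen, no two IDs share a slot. The fallback branch hints at why a real module also needs a policy for freeing slots.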

## TorchRec Zero Collision Hash Modules

TorchRec implements managed collision hash strategies such as *sorted zero collision hash* and *multi-probe zero collision hash (MPZCH)*.

They help hash and remap the feature IDs to embedding table indices with (near-)zero collisions.

In the following, we show how to use the zero-collision modules in TorchRec. The MPZCH module is named `HashZchManagedCollisionModule`; the code cells below instantiate `MCHManagedCollisionModule`, which was imported at the top of this notebook.

Let's assume we have two tables, `table_0` and `table_1`, which hold the embeddings for `feature_0` and `feature_1`, respectively.

In [None]:
# define the table sizes
num_embeddings_table_0 = 1000
num_embeddings_table_1 = 2000

# create table configs
table_0_config = EmbeddingConfig(
    name="table_0",
    embedding_dim=16,
    num_embeddings=num_embeddings_table_0,
    feature_names=["feature_0"],
)

table_1_config = EmbeddingConfig(
    name="table_1",
    embedding_dim=16,
    num_embeddings=num_embeddings_table_1,
    feature_names=["feature_1"],
)

Before turning the table configurations into an embedding collection, we instantiate our managed collision modules.

The managed collision modules for a collection of embedding tables are passed as a dictionary of the form `{table_name: mc_module_for_the_table}`.

In [None]:
mc_modules = {}

# Instantiate one managed collision module per table
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
input_hash_size = 10000  # range of the input feature IDs
mc_modules["table_0"] = MCHManagedCollisionModule(
    zch_size=table_0_config.num_embeddings,  # range of the remapped output IDs
    input_hash_size=input_hash_size,
    device=device,
    eviction_interval=2,  # interval at which the eviction policy is triggered
    eviction_policy=DistanceLFU_EvictionPolicy(),
)
mc_modules["table_1"] = MCHManagedCollisionModule(
    zch_size=table_1_config.num_embeddings,
    device=device,
    input_hash_size=input_hash_size,
    eviction_interval=1,
    eviction_policy=DistanceLFU_EvictionPolicy(),
)
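
The `eviction_policy` argument decides which IDs lose their slot when room is needed. To give an intuition, the sketch below implements a plain least-frequently-used (LFU) eviction in pure Python. It is a simplified stand-in: it evicts on demand when the table is full, it is not the `DistanceLFU_EvictionPolicy` algorithm, and `ToyLFUSlotTable` is a hypothetical name.

```python
from typing import Dict, List


class ToyLFUSlotTable:
    """Toy LFU eviction; not TorchRec's DistanceLFU policy."""

    def __init__(self, table_size: int) -> None:
        self.id_to_slot: Dict[int, int] = {}
        self.counts: Dict[int, int] = {}
        self.free_slots: List[int] = list(range(table_size))

    def lookup(self, raw_id: int) -> int:
        if raw_id not in self.id_to_slot:
            if not self.free_slots:
                # Table is full: evict the least frequently used ID
                # and make its slot available again.
                victim = min(self.counts, key=self.counts.get)
                self.free_slots.append(self.id_to_slot.pop(victim))
                del self.counts[victim]
            self.id_to_slot[raw_id] = self.free_slots.pop(0)
            self.counts[raw_id] = 0
        self.counts[raw_id] += 1
        return self.id_to_slot[raw_id]


toy_table = ToyLFUSlotTable(table_size=2)
for toy_id in [10, 10, 20, 30]:
    print(toy_id, "->", toy_table.lookup(toy_id))
# 10 -> 0
# 10 -> 0
# 20 -> 1
# 30 -> 1  (ID 20 was used least often, so its slot is reused)
```

The frequently used ID 10 keeps its slot, while the rarely used ID 20 is evicted to make room for ID 30.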

For embedding tables with managed collision modules, TorchRec uses a wrapper module, `ManagedCollisionEmbeddingCollection`, that contains both the embedding collection and the managed collision modules. Users only need to pass their table configurations and managed collision modules to it.

In [None]:
# Assign to mc_ec only, so the class name is not rebound to the instance
mc_ec = ManagedCollisionEmbeddingCollection(
    EmbeddingCollection(
        tables=[table_0_config, table_1_config],
        device=device,
    ),
    ManagedCollisionCollection(
        managed_collision_modules=mc_modules,
        embedding_configs=[table_0_config, table_1_config],
    ),
    return_remapped_features=True,  # whether to return the remapped feature IDs
)

The `ManagedCollisionEmbeddingCollection` module performs the remapping and the table lookup for the input. Users only need to pass keyed jagged tensor queries to the module.

In [None]:
input_kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["feature_0", "feature_1"],
    values=torch.tensor([1000, 10001, 2000, 20001]),
    lengths=torch.tensor([1, 1, 1, 1]),
)
for feature_name, feature_jt in input_kjt.to_dict().items():
    print(f"feature name: {feature_name}, feature jt: {feature_jt}")
    print(f"feature jt values: {feature_jt.values()}")

feature name: feature_0, feature jt: JaggedTensor({
    [[1000], [10001]]
})

feature jt values: tensor([ 1000, 10001])
feature name: feature_1, feature jt: JaggedTensor({
    [[2000], [20001]]
})

feature jt values: tensor([ 2000, 20001])
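
The printed result above follows from how the flat `values` and `lengths` lists are laid out: key by key, with one length per sample in the batch. The pure-Python sketch below mimics that layout for illustration; `split_by_lengths` is a hypothetical helper, not a TorchRec function.

```python
from typing import Dict, List


def split_by_lengths(
    keys: List[str], values: List[int], lengths: List[int]
) -> Dict[str, List[List[int]]]:
    """Split flat values into per-key, per-sample lists (key-major layout)."""
    batch_size = len(lengths) // len(keys)
    result: Dict[str, List[List[int]]] = {}
    offset = 0
    for key_index, key in enumerate(keys):
        rows = []
        for row in range(batch_size):
            # Each sample of each key consumes `length` consecutive values.
            length = lengths[key_index * batch_size + row]
            rows.append(values[offset : offset + length])
            offset += length
        result[key] = rows
    return result


toy_layout = split_by_lengths(
    keys=["feature_0", "feature_1"],
    values=[1000, 10001, 2000, 20001],
    lengths=[1, 1, 1, 1],
)
print(toy_layout)
# {'feature_0': [[1000], [10001]], 'feature_1': [[2000], [20001]]}
```

With two keys and four lengths, the batch size is two, so each feature receives two samples of one ID each.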


In [None]:
output_embeddings, remapped_ids = mc_ec(input_kjt.to(device))
# show output embeddings
for feature_name, feature_embedding in output_embeddings.items():
    print(f"feature name: {feature_name}, feature embedding: {feature_embedding}")
# show remapped ids
for feature_name, feature_jt in remapped_ids.to_dict().items():
    print(f"feature name: {feature_name}, feature jt values: {feature_jt.values()}")

feature name: feature_0, feature embedding: JaggedTensor({
    [[[0.022659072652459145, 0.0053002419881522655, -0.025007368996739388, -0.013145492412149906, -0.031139537692070007, -0.01486812811344862, -0.01133741531521082, 0.0027838051319122314, 0.026786740869283676, -0.010626785457134247, 0.01148480549454689, 0.02036162279546261, 0.013492186553776264, -0.024412740021944046, 0.01599711738526821, -0.02390478551387787]], [[-0.029269251972436905, 0.01744556427001953, 0.024260954931378365, 0.029459983110427856, -0.026435773819684982, -0.0034603318199515343, -0.007642757147550583, -0.02111411839723587, 0.027316255494952202, 0.015309474430978298, 0.03137263283133507, 0.01699884422123432, 0.02302604913711548, -0.015266639180481434, -0.019045181572437286, 0.006964980624616146]]]
})

feature name: feature_1, feature embedding: JaggedTensor({
    [[[0.009506281465291977, 0.012826820835471153, -0.0017535268561914563, -0.0009170559933409095, -0.014913717284798622, 0.0040654330514371395, -0.011355

Now we have a basic example of how to use the managed collision modules in TorchRec. 

We also provide a profiling example to compare the efficiency of sorted ZCH and MPZCH modules. Check the [Readme](Readme.md) file for more details.