-
Notifications
You must be signed in to change notification settings - Fork 585
Open
Description
Quantizing embeddings with the float32 data type (`dtype=torch.float`) can crash the process with `Floating point exception (core dumped)`, whereas `dtype=torch.qint8` works. The crash can be reproduced by running `python test_quant.py` (below) in the following environment: torchrec==1.1.0+cu124, torch==2.6.0+cu124, fbgemm-gpu==1.1.0+cu124.
test_quant.py
import torch
import torchrec
from torch import nn
from torchrec import EmbeddingBagCollection
from torchrec.optim.optimizers import in_backward_optimizer_filter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.inference.modules import quantize_embeddings
# Number of tables in each size class; DebugModel reads these as well.
large_table_cnt = 2
small_table_cnt = 2


def _make_table_configs(name_prefix, feature_prefix, count, num_embeddings):
    """Return ``count`` sum-pooled, 64-dim ``EmbeddingBagConfig`` objects.

    Table ``i`` is named ``name_prefix + str(i)``, has ``num_embeddings`` rows,
    and reads the single feature ``feature_prefix + str(i)``.
    """
    configs = []
    for idx in range(count):
        configs.append(
            torchrec.EmbeddingBagConfig(
                name=name_prefix + str(idx),
                embedding_dim=64,
                num_embeddings=num_embeddings,
                feature_names=[feature_prefix + str(idx)],
                pooling=torchrec.PoolingType.SUM,
            )
        )
    return configs


# "Large" tables have 4096 rows, "small" tables 1024 rows; all are 64-dim.
large_tables = _make_table_configs(
    "large_table_", "large_table_feature_", large_table_cnt, 4096
)
small_tables = _make_table_configs(
    "small_table_", "small_table_feature_", small_table_cnt, 1024
)
class DebugModel(nn.Module):
    """Minimal repro model: an EmbeddingBagCollection followed by a linear head.

    The collection covers every table in ``large_tables + small_tables``.
    ``forward`` passes ``ebc(kjt).values()`` through a 1-output linear layer
    and returns the mean of the result as a scalar tensor.
    """

    def __init__(self, device: torch.device):
        super().__init__()
        tables = large_tables + small_tables
        self.ebc = EmbeddingBagCollection(tables=tables, device=device)
        # Input width = total embedding dim across all tables. This assumes
        # values() concatenates the pooled embeddings of every table (the
        # original hard-coded 64 * table count) -- TODO confirm for non-uniform
        # dims. Deriving it from the configs keeps it in sync with the tables.
        in_features = sum(table.embedding_dim for table in tables)
        # Build the head on the same device as the embeddings; otherwise
        # forward() would feed device-resident activations into CPU weights.
        self.linear = nn.Linear(in_features, 1, device=device)

    def forward(self, kjt: KeyedJaggedTensor) -> torch.Tensor:
        """Embed ``kjt``, project it to one logit, and return the mean logit."""
        emb = self.ebc(kjt)
        return torch.mean(self.linear(emb.values()))
if __name__ == "__main__":
    model = DebugModel(device=torch.device("cuda:0"))
    # Quantizing with dtype=torch.qint8 is ok; dtype=torch.float (float32) is
    # reported to abort with "Floating point exception (core dumped)" on
    # torchrec 1.1.0 / torch 2.6.0 / fbgemm-gpu 1.1.0 (all cu124).
    quantize_embeddings(model, dtype=torch.float, inplace=True)
Metadata
Assignees
Labels
No labels