
Add embedding op support for Float8Tensor with PerGroup quantization #4175

Merged
jerryzh168 merged 2 commits into main from gh/jerryzh168/73/head
Mar 26, 2026
Conversation

jerryzh168 (Contributor) commented Mar 25, 2026

Stack from ghstack (oldest at bottom):

Summary:

- Register `aten.embedding.default` and `F.embedding` ops for Float8Tensor
- Implementation dequantizes the weight, then calls the original op
- Two separate registrations with the correct argument order for each dispatch path:
  `aten.embedding.default` takes `(weight, indices)` while `F.embedding` takes `(indices, weight)`
- Works with all granularities (PerTensor, PerRow, PerGroup)
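
The dequantize-and-fallback pattern described above can be sketched in plain Python. This is illustrative only: `FakeFloat8Weight` and the handler names are stand-ins, not the actual torchao registration code, which hooks these handlers into Float8Tensor's aten and torch_function dispatch tables.

```python
# Sketch of "dequantize the weight, then call the original op", showing
# the two argument orders noted in the summary. Pure-Python stand-ins.

class FakeFloat8Weight:
    """Stand-in for Float8Tensor: holds quantized rows plus one scale."""
    def __init__(self, qrows, scale):
        self.qrows = qrows
        self.scale = scale

    def dequantize(self):
        # Recover high-precision values from quantized rows.
        return [[q * self.scale for q in row] for row in self.qrows]

def embedding_aten(weight, indices):
    # aten.embedding.default dispatch path: (weight, indices)
    table = weight.dequantize()
    return [table[i] for i in indices]

def embedding_torch_function(indices, weight):
    # F.embedding dispatch path: (indices, weight) -- note swapped order
    return embedding_aten(weight, indices)

w = FakeFloat8Weight([[1, 2], [3, 4], [5, 6]], scale=0.5)
print(embedding_torch_function([2, 0], w))  # [[2.5, 3.0], [0.5, 1.0]]
```

Keeping the shared logic in one helper and only swapping argument order mirrors why two thin registrations (rather than two implementations) suffice.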

Test Plan:

  • Added test_fp8_embedding parametrized with PerRow/PerGroup(64) and both dispatch paths (torch_function and aten)
  • Verifies quantize_ succeeds on nn.Embedding, output shape matches, and SQNR > 20
  • Run: python -m pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -xvs -k "test_fp8_embedding"
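
The SQNR > 20 check above can be illustrated with a minimal pure-Python sketch (the actual test uses torchao's tensor-based SQNR helper; `sqnr_db` here is a hypothetical name):

```python
import math

def sqnr_db(reference, quantized):
    # Signal-to-quantization-noise ratio in dB: 10 * log10(P_signal / P_error).
    # Above 20 dB, the quantization error power is under 1% of signal power.
    sig_pow = sum(x * x for x in reference)
    err_pow = sum((x - y) ** 2 for x, y in zip(reference, quantized))
    return 10 * math.log10(sig_pow / err_pow)

ref = [1.0, -2.0, 3.0, -4.0]   # original embedding row
deq = [1.01, -1.99, 3.02, -3.98]  # dequantized approximation
print(sqnr_db(ref, deq) > 20)  # True
```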

Summary:

- Add `granularity` field to `Float8WeightOnlyConfig` supporting `PerTensor`, `PerRow` (default), and `PerGroup`
- Previously hardcoded to `PerRow()`; now users can pass e.g. `Float8WeightOnlyConfig(granularity=PerGroup(64))` for per-group float8 weight quantization
- Update `FP8Granularity` type alias to include `PerGroup`

Test Plan:

- Added PerGroup(64) to existing test_fp8_linear_variants parametrization
- Updated check_weight_scaling helpers to verify PerGroup scale shapes (N, K // group_size)
- Existing weight-only tests now pass granularity to config instead of ignoring it
- Run: python -m pytest test/quantization/quantize_/workflows/float8/test_float8_tensor.py -xvs -k "Float8WeightOnly"


pytorch-bot commented Mar 25, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4175

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8a212ce with merge base 02105d4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label Mar 25, 2026
jerryzh168 added a commit that referenced this pull request Mar 25, 2026
ghstack-source-id: c57981c
Pull Request resolved: #4175
jerryzh168 added the module: inference label Mar 25, 2026
jerryzh168 changed the base branch from gh/jerryzh168/73/base to main on Mar 26, 2026
jerryzh168 merged commit 4611835 into main Mar 26, 2026
36 of 38 checks passed