
Use fixed scale for Float8 softmax quantization instead of observer #4260

Merged
jerryzh168 merged 4 commits into main from gh/jerryzh168/83/head on Apr 11, 2026
Conversation

@jerryzh168 (Contributor) commented Apr 9, 2026

Stack from ghstack (oldest at bottom):

Summary:
Softmax output is always in [0, 1], so observing during calibration is
unnecessary. Replace the observer-based flow with a fixed output scale
of `finfo(float8_dtype).max / 1.0` (448.0 for float8_e4m3fn). This
simplifies `Float8ObservedSoftmax` to a plain marker module and removes
the observer allocation + calibration overhead for softmax layers.

Test Plan:
pytest test/prototype/test_prototype_float8_tensor.py -k test_static_quant_softmax -x
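A minimal sketch of the fixed-scale idea (illustrative only, not the PR's actual code), assuming the multiply-by-scale convention implied by `finfo(float8_dtype).max / 1.0`:

```python
import torch

float8_dtype = torch.float8_e4m3fn
# Fixed scale: no observer needed, since softmax output never exceeds 1.0.
scale = torch.finfo(float8_dtype).max / 1.0  # 448.0 for float8_e4m3fn

x = torch.softmax(torch.randn(4, 8), dim=-1)  # softmax output is in [0, 1]
q = (x * scale).to(float8_dtype)              # quantize: [0, 1] -> [0, 448]
x_hat = q.to(torch.float32) / scale           # dequantize back to [0, 1]
```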

@jerryzh168 jerryzh168 requested a review from vkuzo as a code owner April 9, 2026 23:24
@pytorch-bot (Bot) commented Apr 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4260

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

⏳ No Failures, 3 Pending

As of commit 2265e5a with merge base c554b1f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Apr 9, 2026 (ghstack-source-id: 87f4197)
@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Apr 9, 2026
@jerryzh168 jerryzh168 added the module: not user facing Use this tag if you don't want this PR to show up in release notes label Apr 9, 2026
… observer"

Summary:
Softmax output is always in [0, 1], so observing during calibration is
unnecessary. Replace the observer-based flow with a fixed output scale
of `finfo(float8_dtype).max / 1.0` (448.0 for float8_e4m3fn). This
simplifies `Float8ObservedSoftmax` to a plain marker module and removes
the observer allocation + calibration overhead for softmax layers.

Test Plan:
pytest test/prototype/test_prototype_float8_tensor.py -k test_static_quant_softmax -x

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request Apr 9, 2026 (ghstack-source-id: ff5b439)
… observer"

Summary:
Softmax output is always in [0, 1], so observing during calibration is
unnecessary. Replace the observer-based flow with a fixed output scale
of `finfo(float8_dtype).max / 1.0` (448.0 for float8_e4m3fn). This
simplifies `Float8ObservedSoftmax` to a plain marker module and removes
the observer allocation + calibration overhead for softmax layers.

Test Plan:
pytest test/prototype/test_prototype_float8_tensor.py -k test_static_quant_softmax -x

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request Apr 9, 2026 (ghstack-source-id: 4326f3f)
@supriyar (Contributor) commented:

@claude review pls

@claude (Bot) commented Apr 10, 2026

Claude finished @supriyar's task in 1m 29s.


Review of PR #4260

  • Read the PR diff and changed files
  • Review production code changes (torchao/prototype/quantization/quant_api.py)
  • Review test changes (test/prototype/test_prototype_float8_tensor.py)
  • Post review feedback

The core idea is sound — softmax output is guaranteed to be in [0, 1], so a fixed scale of float8_max / 1.0 is both correct and simpler than running an observer during calibration. This removes unnecessary state and computation. A few items to consider:

Issues

1. _device is None if forward() is never called before convert (bug)
Float8ObservedSoftmax tracks self._device by recording input.device in forward(). If convert is called without any calibration pass (the observer is gone, so calibration is technically optional now), _device will be None and the scale tensor will be created with device=None, i.e., on the default CPU device.

This may silently work (CPU default), but it's fragile and differs from the intended device. Consider either:

  • Defaulting _device to something explicit (e.g. "cpu") so behavior is deterministic, or
  • Adding a guard in the convert step that warns or errors if _device is None.

Since the calibration loop is still part of the documented flow (and the test does call forward()), this is a minor concern, but worth hardening.
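A sketch of the second option (hypothetical helper and attribute names, following this review's description rather than the PR's actual code):

```python
import warnings

import torch

def make_output_scale(observed_mod: torch.nn.Module, scale_value: float) -> torch.Tensor:
    # Fall back to CPU with a warning when no calibration forward()
    # ever recorded a device on the observed softmax module.
    device = getattr(observed_mod, "_device", None)
    if device is None:
        warnings.warn("convert called before any forward(); creating scale on CPU")
        device = "cpu"
    return torch.tensor(scale_value, device=device)
```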

2. Scale reshaping with while loop in forward() runs every call
Float8QuantizedSoftmax.forward() (lines 440-442) reshapes the scale on every forward pass:

```python
scale = self.output_act_quant_scale
while scale.ndim < output.ndim:
    scale = scale.unsqueeze(0)
```

This is correct but creates new tensors each call. Since the scale is a fixed buffer, consider pre-computing the reshaped scale in from_observed() or __init__() based on a known expected ndim, or caching the result. For a hot inference path this adds minor overhead per call. Not a blocker, but worth noting.
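One possible shape for such a fix, recomputing only when the output rank changes (a sketch under assumed names; the helper the PR later adds may differ in detail):

```python
import torch

class Float8QuantizedSoftmaxSketch(torch.nn.Module):
    """Hypothetical stand-in for Float8QuantizedSoftmax (illustrative)."""

    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.register_buffer("output_act_quant_scale", scale)
        self._reshaped_scale = None  # plain attribute cache

    def _get_reshaped_scale(self, ndim: int) -> torch.Tensor:
        # Recompute only when the requested ndim changes; the scale itself
        # is a fixed buffer, so the cached view never goes stale otherwise.
        cached = self._reshaped_scale
        if cached is None or cached.ndim != ndim:
            scale = self.output_act_quant_scale
            while scale.ndim < ndim:
                scale = scale.unsqueeze(0)
            self._reshaped_scale = scale
        return self._reshaped_scale
```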

Nits

3. Unused-import check
After this change, AffineQuantizedMinMaxObserver is still imported at line 486 inside _float8_static_activation_float8_weight_transform. It's still used for the Linear path, so this is fine — just flagging that I checked.

4. Calibration loop is now a no-op for softmax
The test still runs 10 calibration iterations (line 520), which now only serve to record _device. A brief comment in the test or docstring noting that the calibration loop is only needed for linear modules (and that the softmax path ignores it) would improve clarity for future readers.

Test Coverage

The new test_static_quant_softmax_device_and_dims test is a good addition — it covers CPU/CUDA and 2D/3D/4D shapes, which directly validates the while scale.ndim < output.ndim reshaping logic. One thing missing: a test where convert is called without any prior forward() call, to document what happens with _device=None.

Summary

The motivation and approach are clean. The fixed scale of float8_max for softmax output in [0, 1] is mathematically correct and removes unnecessary observer machinery. The main concern is the _device=None edge case — the rest is minor.


… observer"

Summary:
Softmax output is always in [0, 1], so observing during calibration is
unnecessary. Replace the observer-based flow with a fixed output scale
of `finfo(float8_dtype).max / 1.0` (448.0 for float8_e4m3fn). This
simplifies `Float8ObservedSoftmax` to a plain marker module and removes
the observer allocation + calibration overhead for softmax layers.

Test Plan:
pytest test/prototype/test_prototype_float8_tensor.py -k test_static_quant_softmax -x

[ghstack-poisoned]
jerryzh168 added a commit that referenced this pull request Apr 10, 2026 (ghstack-source-id: 39b9d0e)
@jerryzh168 (Contributor, Author) commented:

@claude review again

@claude (Bot) commented Apr 10, 2026

Claude finished @jerryzh168's task in 2m 12s.


Re-review of PR #4260

  • Read the PR diff and changed files
  • Review production code changes (torchao/prototype/quantization/quant_api.py)
  • Review test changes (test/prototype/test_prototype_float8_tensor.py)
  • Post review feedback

All three concerns from the previous review have been addressed in the latest revision:

  1. _device=None edge case — Now handled with an explicit warning and CPU fallback at quant_api.py:540-545. Good.

  2. Scale reshaping per-call overhead — Replaced with _get_reshaped_scale() at quant_api.py:435-446 that caches the result and only recomputes when ndim changes. This addresses the concern cleanly.

  3. Calibration loop comment — Added at test_prototype_float8_tensor.py:520-521 explaining the loop is only needed for linear observers and that softmax ignores it.

Remaining observations

1. _reshaped_scale cache invalidation with torch.compile

The _reshaped_scale is stored as a plain Python attribute (not a buffer), so it won't be captured by torch.compile's tracing. During compiled execution, self._reshaped_scale will always be None from the graph's perspective, causing the reshape to run every time. This isn't a correctness bug — the reshape path is correct — but the caching provides no benefit under torch.compile. Since the reshape is just a few unsqueeze calls, this is fine in practice. Just noting for transparency.

2. Scale value correctness

The fixed scale of float8_max / 1.0 = 448.0 (for float8_e4m3fn) is mathematically sound. Softmax output is in [0, 1], and quantization multiplies by the scale to map values into the float8 range: with scale = 448.0, any value in [0, 1] maps to [0, 448] in float8 representation, which is within range. Verified that PrototypeFloat8Tensor.from_hp at prototype_float8_tensor.py:183-185 asserts scale.ndim == hp_tensor.ndim, which the _get_reshaped_scale method satisfies.
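A quick numeric check of that mapping (illustrative; 0, 224, and 448 are exactly representable in float8_e4m3fn, so this particular round trip is exact):

```python
import torch

f8 = torch.float8_e4m3fn
scale = torch.finfo(f8).max  # 448.0

x = torch.tensor([0.0, 0.5, 1.0])
q = (x * scale).to(f8)       # 0.0, 224.0, 448.0 stored as float8
assert torch.equal(q.to(torch.float32) / scale, x)  # exact round trip here
```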

3. New test test_static_quant_softmax_device_and_dims is solid

Covers CPU/CUDA and 2D/3D/4D shapes. Verifies scale device placement, output device, shape, and dtype. One minor gap: no test for the _device=None warning path (convert without prior forward()), but since that path now has a sensible default, this is low priority.

Summary

The previous review issues have been addressed. The change is clean and correct. LGTM.



@jerryzh168 jerryzh168 changed the base branch from gh/jerryzh168/83/base to main April 10, 2026 23:30
@jerryzh168 jerryzh168 merged commit 5e558f5 into main Apr 11, 2026
36 checks passed