
Add keepdim support for AffineQuantizedMinMaxObserver#3748

Merged
jerryzh168 merged 1 commit into main from gh/jerryzh168/29/head
Jan 29, 2026

Conversation

@jerryzh168
Contributor

@jerryzh168 jerryzh168 commented Jan 28, 2026

Stack from ghstack (oldest at bottom):

Summary:

This PR adds a keepdim parameter to AffineQuantizedMinMaxObserver and related functions, allowing users to preserve dimensions in quantization parameters (scale/zero_point) that match the shape of min/max
statistics.

Motivation

When computing quantization parameters, it's useful to maintain dimension alignment between min/max statistics and the resulting scale/zero_point tensors. This simplifies broadcasting operations and makes the
tensors easier to work with in downstream quantization workflows.

This is needed to support per-tensor float8 quantization with 3D inputs.
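To illustrate the broadcasting point above, here is a small sketch in plain PyTorch (not torchao code): a per-axis min of shape `[10]` does not broadcast against a `[10, 4]` input, while the keepdim shape `[10, 1]` lines up without a reshape.

```python
import torch

# Per-axis reduction over dim 1 of a [10, 4] tensor.
x = torch.randn(10, 4)

min_flat = x.amin(dim=1)                # shape [10]
min_kept = x.amin(dim=1, keepdim=True)  # shape [10, 1]

# Broadcasting aligns trailing dims, so [10] is matched against the
# last dim of x (size 4) and fails; [10, 1] lines up with no reshape.
try:
    _ = x / min_flat
    flat_broadcasts = True
except RuntimeError:
    flat_broadcasts = False

aligned = x / min_kept  # broadcasts cleanly to [10, 4]
print(flat_broadcasts, aligned.shape)
```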

Changes

1. Core Quantization Primitives

- torchao/quantization/quant_primitives.py:
  - Added keepdim: bool = False parameter to choose_qparams_affine_with_min_max()
  - When keepdim=True, scale/zero_point retain the same shape as min_val/max_val

2. Observer Base Class

- torchao/quantization/observer.py:
  - Added keepdim: bool = False parameter to AffineQuantizedObserverBase.__init__()
  - Stored as self.keepdim for use in derived classes

3. MinMax Observer

- torchao/quantization/observer.py:
  - Updated AffineQuantizedMinMaxObserver.forward() to use keepdim=self.keepdim in torch.amin/amax calls
  - Updated AffineQuantizedMinMaxObserver.calculate_qparams() to pass keepdim to choose_qparams_affine_with_min_max()

4. Tests

- test/quantization/test_observer.py:
  - Added test_keepdim_per_tensor(): verifies keepdim behavior for per-tensor quantization
  - Added test_keepdim_per_axis(): verifies keepdim behavior for per-axis quantization
  - Tests confirm that with keepdim=True, scale/zero_point shapes match min_val/max_val shapes
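The changes above can be sketched with a minimal stand-in observer (this is illustrative only, not the torchao implementation; the class name and constructor parameters here are hypothetical): keepdim is threaded into the amin/amax calls in forward(), and the resulting qparams inherit the min/max shapes.

```python
import torch

class MinimalMinMaxObserver:
    """Illustrative stand-in for AffineQuantizedMinMaxObserver (not torchao code)."""

    def __init__(self, reduction_dims, quant_max=127, keepdim=False):
        self.reduction_dims = reduction_dims  # dims reduced away by amin/amax
        self.quant_max = quant_max
        self.keepdim = keepdim
        self.min_val = None
        self.max_val = None

    def forward(self, x):
        # keepdim is threaded into the amin/amax calls, mirroring the PR.
        self.min_val = torch.amin(x, dim=self.reduction_dims, keepdim=self.keepdim)
        self.max_val = torch.amax(x, dim=self.reduction_dims, keepdim=self.keepdim)
        return x

    def calculate_qparams(self):
        # Symmetric scale from min/max; the qparam shapes follow
        # min_val/max_val, so keepdim=True keeps size-1 reduced axes.
        max_abs = torch.maximum(self.max_val.abs(), self.min_val.abs())
        scale = max_abs / self.quant_max
        zero_point = torch.zeros_like(scale)
        return scale, zero_point

obs = MinimalMinMaxObserver(reduction_dims=(0, 1), keepdim=True)
obs.forward(torch.randn(8, 16))
scale, zp = obs.calculate_qparams()
print(scale.shape)  # keepdim=True keeps both reduced dims as size 1
```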

Behavior

With keepdim=False (default, backward compatible):

Per-tensor (2D input):
min_val.shape = []  # scalar
scale.shape = []    # scalar

Per-axis:
min_val.shape = [10]
scale.shape = [10]

With keepdim=True:

Per-tensor (2D input):
min_val.shape = [1, 1]
scale.shape = [1, 1]

Per-axis:
min_val.shape = [10, 1]
scale.shape = [10, 1]
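The shapes above follow directly from how torch.amin/amax handle keepdim; a quick check on a 2D input (written here with plain torch, independent of the observer code):

```python
import torch

x2d = torch.randn(10, 4)

# Per-tensor: reduce over every dim.
pt_flat = x2d.amin(dim=(0, 1))                # shape []  (scalar)
pt_kept = x2d.amin(dim=(0, 1), keepdim=True)  # shape [1, 1]

# Per-axis: reduce dim 1, keep dim 0 (size 10).
pa_flat = x2d.amin(dim=1)                 # shape [10]
pa_kept = x2d.amin(dim=1, keepdim=True)   # shape [10, 1]

# Values are identical either way; only the shapes differ.
same_values = torch.equal(pa_flat, pa_kept.squeeze(1))
print(pt_flat.shape, pt_kept.shape, pa_flat.shape, pa_kept.shape, same_values)
```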

Backward Compatibility

✅ Fully backward compatible - keepdim defaults to False, preserving existing behavior.

Testing

  • Unit tests added for both PerTensor (both 2D and 3D inputs) and PerAxis granularities
  • Tests verify correct shapes and equivalent values between keepdim=True/False

Test Plan:
pytest test/quantization/test_observer.py -k keepdim

Reviewers:

Subscribers:

Tasks:

Tags:

@pytorch-bot

pytorch-bot bot commented Jan 28, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3748

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 3 Pending

As of commit 7b673a8 with merge base a003c32:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

jerryzh168 added a commit that referenced this pull request Jan 28, 2026

ghstack-source-id: 04cb012
Pull Request resolved: #3748
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jan 28, 2026
@jerryzh168 jerryzh168 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Jan 28, 2026
@jerryzh168 jerryzh168 changed the base branch from gh/jerryzh168/29/base to main January 28, 2026 23:57
@jerryzh168 jerryzh168 merged commit c6c34d4 into main Jan 29, 2026
36 of 38 checks passed
