Description
Summary
We propose a plan for extending TorchAO so that new accelerators can be supported more easily.
Key points:
- Generic quantization path – a universal implementation available to all backends.
  - Quantization stage: lower performance priority. Implemented with standard torch.ops composition to ensure generality and compatibility.
  - Inference stage: performance-sensitive, higher priority. Uses PyTorch’s operator registration/dispatch mechanism to balance performance and portability.
- Specific hardware path – for backends that require custom, non-generic implementations.
  - Example: the current XPU INT4 weight quantization branch, which uses a backend-specific implementation.
Current Status
Take INT4 weight quantization as an example.
Current state:
- FORMAT: PLAIN → Intended as a multi-device generic format, but currently restricted to CUDA due to FBGEMM limitations.
- FORMAT: PLAIN_INT32 → A custom quantization format path specific to the XPU backend.
Proposed Implementation
We continue with INT4 weight quantization as the running example.
Generic Quantization Path
Quantization Stage
The quantization stage mainly consists of two types of operations: quantization and packing.
These operations are implemented in AO with standard Torch operators to guarantee portability:
def int4_row_quantize_zp(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # quantize operator, implemented with standard torch ops
    ...

def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # pack operator, implemented with standard torch ops
    ...
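For concreteness, below is a minimal, self-contained sketch of what such generic implementations could look like when written purely with standard torch ops. The group layout, the asymmetric [0, 15] range, and the packing order are illustrative assumptions, not necessarily TorchAO's exact conventions:

from typing import Tuple

import torch

def int4_row_quantize_zp(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Group each row into chunks of `group_size` values.
    grouped = x.reshape(-1, group_size).to(torch.float32)
    # Per-group asymmetric quantization to the unsigned INT4 range [0, 15].
    min_val = grouped.amin(dim=1, keepdim=True)
    max_val = grouped.amax(dim=1, keepdim=True)
    scale = (max_val - min_val).clamp(min=1e-6) / 15.0
    zero_point = (-min_val / scale).round().clamp(0, 15)
    q = (grouped / scale + zero_point).round().clamp(0, 15).to(torch.uint8)
    return q.reshape(x.shape), scale, zero_point

def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # Pack two INT4 values into each byte along the last dimension.
    low = x[..., ::2]
    high = x[..., 1::2]
    return (high << 4) | low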
Inference Stage

Take the bf16i4bf16_rowwise operator as an example:
AO-side Implementation (TorchAO Repo)
- Register the operator signature through PyTorch’s operator registration system:
TORCH_LIBRARY(ao, m) {
  m.def("bf16i4bf16_rowwise(Tensor input, Tensor weight, Tensor scale, Tensor zero_point) -> Tensor");
}
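For reference, the same schema could in principle be declared from Python via torch.library. This is only an illustrative alternative to the C++ macro above, not a required part of the proposal:

import torch

# Illustrative alternative: declare the same operator schema from Python.
torch.library.define(
    "ao::bf16i4bf16_rowwise",
    "(Tensor input, Tensor weight, Tensor scale, Tensor zero_point) -> Tensor",
)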
- Provide a generic fallback implementation using standard Torch operators:
at::Tensor bf16i4bf16_rowwise(
    at::Tensor X,  // BF16
    at::Tensor W,  // INT4
    at::Tensor w_scale_group,
    at::Tensor w_zero_group
) {
  // Performs BF16 × INT4 computation
  // Uses standard PyTorch operators to enable a generic fallback implementation
  ...
}
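For intuition, the computation such a fallback performs can be written at the Python level as dequantize-then-matmul. The sketch below assumes the packing and grouping layout from the quantization sketch earlier (two INT4 values per byte, per-group scale and zero point); it only illustrates the math and is not the proposed C++ fallback itself:

import torch

def bf16i4bf16_rowwise_reference(
    x: torch.Tensor,              # [M, K], bf16 activations
    w_packed: torch.Tensor,       # [N, K // 2], uint8, two INT4 values per byte
    w_scale_group: torch.Tensor,  # [N, K // group_size]
    w_zero_group: torch.Tensor,   # [N, K // group_size]
    group_size: int = 128,
) -> torch.Tensor:
    # Unpack two INT4 values from each byte (order matches pack_int4 above).
    low = w_packed & 0x0F
    high = (w_packed >> 4) & 0x0F
    w_q = torch.stack([low, high], dim=-1).reshape(w_packed.shape[0], -1).to(torch.float32)
    # Dequantize per group: w = (q - zero_point) * scale.
    n, k = w_q.shape
    w_q = w_q.reshape(n, k // group_size, group_size)
    w = (w_q - w_zero_group.unsqueeze(-1)) * w_scale_group.unsqueeze(-1)
    w = w.reshape(n, k).to(torch.bfloat16)
    # Plain bf16 matmul against the dequantized weight.
    return x @ w.t()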
Device-side Implementation (Device Repo)
- Register a device-specific kernel with AO. During inference, this implementation will be prioritized:
TORCH_LIBRARY_IMPL(ao, device, m) {
  m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
}
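Backends that prefer to register their kernels from Python could achieve the same effect with torch.library.impl; the XPU dispatch key below is only an example:

import torch

# Illustrative Python-side registration of a backend-specific kernel (here XPU).
@torch.library.impl("ao::bf16i4bf16_rowwise", "XPU")
def bf16i4bf16_rowwise_xpu(input, weight, scale, zero_point):
    # Call into the backend's optimized kernel here.
    ...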
When this operator is invoked in AO, the dispatcher selects the device-specific implementation if available. Otherwise, it falls back to the generic implementation:
res = torch.ops.ao.bf16i4bf16_rowwise(
    input_tensor,
    weight_tensor.qdata,
    weight_tensor.scale,
    weight_tensor.zero_point,
)
Specific Hardware Path
For devices requiring custom, non-generic implementations, a separate branch can be created. The existing XPU implementation serves as a reference:
elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
    # device-specific path for PLAIN_INT32 format
    new_weight = Int4PlainInt32Tensor.from_hp(
        weight,
        block_size,
    )
    return new_weight
Benefits
Improved extensibility: by building the generic path on standard PyTorch operators, hardware backends gain quantization support with wide portability, while still being able to register optimized device-specific kernels for performance-critical inference.
Note on Scope
This RFC proposes a general extensible quantization framework. INT4 weight quantization is used as an example, but the approach is generic. The initial implementation may focus on INT4, with later extensions to INT8, FLOAT8, or other quantization schemes.