[RFC] A Proposal to Support More Accelerators for TorchAO #3044

@orangeH25

Description

Summary

We propose a plan for extending TorchAO so that additional accelerators can be supported more easily.

Key points:

  • Generic quantization path – a universal implementation available for all backends.
    • Quantization stage: Lower performance priority. Implemented using standard torch.ops composition to ensure generality and compatibility.
    • Inference stage: Performance-sensitive. Higher priority. Uses PyTorch’s operator registration/dispatch mechanism to balance performance and portability.
  • Specific hardware path – for backends that require custom, non-generic implementations.
    • Example: The current XPU INT4 weight quantization branch, which uses a backend-specific implementation.

Current Status

Take INT4 weight quantization as an example.

Current state:

  • FORMAT: PLAIN → Intended as a multi-device generic format, but currently restricted to CUDA due to FBGEMM limitations.
  • FORMAT: PLAIN_INT32 → A custom quantization format path specific to the XPU backend.

Proposed Implementation

We continue with INT4 weight quantization as the running example.

Generic Quantization Path

Quantization Stage

The quantization stage mainly consists of two types of operations: quantization and packing.

These operations are implemented in AO with standard Torch operators to guarantee portability:

from typing import Tuple

import torch

def int4_row_quantize_zp(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # quantize operator, implemented with standard torch ops
    ...

def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # pack operator, implemented with standard torch ops
    ...
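
For illustration, here is a minimal sketch of what a generic implementation of these two operators could look like using only standard torch ops. The grouping layout, rounding behavior, and packing order are simplifying assumptions made for this sketch, not the actual TorchAO implementation:

import torch
from typing import Tuple

def int4_row_quantize_zp(
    x: torch.Tensor,
    group_size: int = 128,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Group each row into chunks of group_size values.
    grouped = x.to(torch.float32).reshape(-1, group_size)
    # Asymmetric (zero-point) quantization into the unsigned 4-bit range [0, 15].
    min_val = grouped.amin(dim=1, keepdim=True)
    max_val = grouped.amax(dim=1, keepdim=True)
    scale = (max_val - min_val).clamp(min=1e-6) / 15.0
    zero_point = (-min_val / scale).round().clamp(0, 15)
    q = (grouped / scale + zero_point).round().clamp(0, 15).to(torch.uint8)
    return (
        q.reshape(x.shape),
        scale.reshape(x.shape[0], -1),
        zero_point.reshape(x.shape[0], -1),
    )

def pack_int4(x: torch.Tensor) -> torch.Tensor:
    # Pack two adjacent 4-bit values into one uint8 along the last dimension.
    low, high = x[..., ::2], x[..., 1::2]
    return (high << 4) | low

Because this path is built only from standard operators, it runs on any backend that implements the basic torch ops, which is exactly the portability the quantization stage prioritizes over raw performance.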

Inference Stage

Take the bf16i4bf16_rowwise operator as an example:

AO-side Implementation (TorchAO Repo)

  1. Register the operator signature (schema) through PyTorch’s operator registration system:
TORCH_LIBRARY(ao, m) {
    m.def("bf16i4bf16_rowwise(Tensor input, Tensor weight, Tensor scale, Tensor zero_point) -> Tensor");
}
  2. Provide a generic fallback implementation using standard Torch operators:
at::Tensor bf16i4bf16_rowwise(
    at::Tensor X, // BF16
    at::Tensor W, // INT4
    at::Tensor w_scale_group,
    at::Tensor w_zero_group
) {
    // Performs BF16 × INT4 computation
    // Uses standard PyTorch operators to enable a generic fallback implementation
    ...
}

Device-side Implementation (Device Repo)

  • Register a device-specific kernel with AO (here, device stands for the backend’s dispatch key, e.g. XPU or CUDA). During inference, this implementation is prioritized over the generic fallback:
TORCH_LIBRARY_IMPL(ao, device, m) {
    m.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise);
}

When this operator is invoked in AO, the dispatcher selects the device-specific implementation if available. Otherwise, it falls back to the generic implementation:

res = torch.ops.ao.bf16i4bf16_rowwise(
    input_tensor,
    weight_tensor.qdata,
    weight_tensor.scale,
    weight_tensor.zero_point,
)
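
To make the registration and dispatch flow above concrete end to end, here is a minimal, self-contained Python sketch using torch.library. The ao_demo namespace, the dequantize-then-matmul math, and the assumption that the weight arrives as unpacked int8 values with row-wise scale/zero_point are illustrative choices for this sketch only, not the proposed AO implementation:

import torch
from torch.library import Library

# Hypothetical namespace used only for this sketch, so it does not clash with
# operators that TorchAO itself registers.
demo_lib = Library("ao_demo", "DEF")
demo_lib.define(
    "bf16i4bf16_rowwise(Tensor input, Tensor weight, Tensor scale, Tensor zero_point) -> Tensor"
)

def bf16i4bf16_rowwise_generic(input, weight, scale, zero_point):
    # Generic fallback built from standard torch ops:
    # dequantize the weight, then run a plain matmul.
    w = (weight.to(torch.float32) - zero_point) * scale
    return (input.to(torch.float32) @ w.t()).to(torch.bfloat16)

# Registering under CompositeExplicitAutograd makes this the fallback on any
# backend that has not registered its own kernel for the operator.
demo_lib.impl("bf16i4bf16_rowwise", bf16i4bf16_rowwise_generic, "CompositeExplicitAutograd")

# A device repo would additionally register its optimized kernel, e.g.
#   demo_lib.impl("bf16i4bf16_rowwise", xpu_kernel, "XPU")
# and the dispatcher would then prefer it whenever the inputs live on that device.

x = torch.randn(2, 8, dtype=torch.bfloat16)          # BF16 activations
w = torch.randint(0, 16, (4, 8), dtype=torch.int8)   # unpacked INT4 values
scale = torch.rand(4, 1)
zero_point = torch.full((4, 1), 8.0)
res = torch.ops.ao_demo.bf16i4bf16_rowwise(x, w, scale, zero_point)

With this structure, AO-side code only ever calls the registered operator; whether the generic fallback or a vendor kernel runs is decided by the dispatcher, which is what keeps the inference path both portable and tunable per backend.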

Specific Hardware Path

For devices requiring custom, non-generic implementations, a separate branch can be created. The existing XPU implementation serves as a reference:

elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
    # device-specific path for the PLAIN_INT32 format
    new_weight = Int4PlainInt32Tensor.from_hp(
        weight,
        block_size,
    )
    return new_weight

Benefits

Improved extensibility: because the generic path is built from standard PyTorch operators, any hardware backend can leverage quantization support with wide portability, while the registration/dispatch mechanism still lets backends plug in optimized kernels where performance matters.

Note on Scope

This RFC proposes a general extensible quantization framework. INT4 weight quantization is used as an example, but the approach is generic. The initial implementation may focus on INT4, with later extensions to INT8, FLOAT8, or other quantization schemes.
