TorchAOBaseTensor should provide an interface to convert to high precision tensor #3118

@vkuzo

Description

We should have a consistent API for converting TorchAOBaseTensor children back to high precision. This would simplify 3p integrations (such as vllm-project/vllm#25480), where we may want to convert back to bfloat16 as a fallback path that maps to weight-only quantization. It also seems generally useful for increasing consistency across the codebase.

My tentative proposal is to do one of the following for all quantized tensors in torchao:

# option 1 - require target dtype
def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
    """Returns the tensor dequantized to `target_dtype`."""
    ...

# option 2 - optional target dtype, falling back to the dtype the tensor was created from
def to_hp(self, target_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    """Returns the tensor dequantized to `target_dtype`."""
    if target_dtype is None:
        target_dtype = self.dtype  # original high precision dtype used to create the tensor
    ...
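
For illustration, below is a minimal runnable sketch of option 2 on a toy wrapper subclass. ToyInt8Tensor, its qdata/scale attributes, and the naive per-tensor int8 scheme are hypothetical stand-ins made up for this sketch, not torchao's actual implementation; only the torch calls are real APIs.

from typing import Optional

import torch


class ToyInt8Tensor(torch.Tensor):
    """Toy stand-in for a TorchAOBaseTensor child: packs int8 data plus a
    per-tensor scale, while reporting `dtype` as the original high precision
    dtype (see caveat 2 below)."""

    @staticmethod
    def __new__(cls, qdata, scale, orig_dtype):
        # report the original high precision dtype as `self.dtype`
        return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, dtype=orig_dtype)

    def __init__(self, qdata, scale, orig_dtype):
        self.qdata = qdata
        self.scale = scale

    def to_hp(self, target_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # option 2: fall back to the original high precision dtype
        if target_dtype is None:
            target_dtype = self.dtype
        return (self.qdata.to(torch.float32) * self.scale).to(target_dtype)


w = torch.randn(16, 16, dtype=torch.bfloat16)
scale = w.abs().amax() / 127
q = ToyInt8Tensor((w / scale).round().clamp(-128, 127).to(torch.int8), scale, w.dtype)
assert q.to_hp().dtype == torch.bfloat16              # fallback: original dtype
assert q.to_hp(torch.float32).dtype == torch.float32  # explicit target dtype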

Some caveats:

  1. we should be careful when reusing the name dequantize, as that op is defined in PyTorch to always dequantize to float32 (link), while we usually want to dequantize to the original dtype, which is often torch.bfloat16
  2. we cannot reuse to, as our tensor subclasses usually set self.dtype to the dtype of the original high precision tensor before quantization, so a to call targeting that dtype cannot be distinguished from a no-op (see the sketch after this list)
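
To make caveat 2 concrete with the toy sketch above: the subclass already reports the high precision dtype, so to has no way to express "please dequantize".

print(q.dtype)        # torch.bfloat16, even though the packed data is int8
print(q.qdata.dtype)  # torch.int8
# q.to(torch.bfloat16) therefore cannot signal "dequantize me": the requested
# dtype already equals q.dtype, which is why a dedicated method is needed.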
