We should have a consistent API for converting TorchAOBaseTensor children back to high precision. This is useful for simplifying third-party integrations (such as vllm-project/vllm#25480), where we may want to convert to bfloat16 to map to weight-only quantization as a fallback path. It also seems generally useful for increasing consistency in the codebase.
My tentative proposal is to adopt one of the following for all quantized tensors in torchao:
```python
# option 1 - require target dtype
def to_dtype(self, target_dtype: torch.dtype) -> torch.Tensor:
    ...  # returns tensor dequantized to `target_dtype`

# option 2 - with a fallback path using the dtype the tensor was created from
def to_hp(self, target_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    if target_dtype is None:
        target_dtype = self.dtype  # original high precision dtype used to create the tensor
    ...  # returns tensor dequantized to `target_dtype`
```

Some caveats:
- we should be careful when using `dequantize`, as this op is defined in PyTorch as always dequantizing to float32 (link), and we usually want to dequantize to the original dtype, which is often `torch.bfloat16`
- we cannot use `to`, as our tensor subclasses usually set `self.dtype` to be the dtype of the original high precision tensor before quantization.
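For illustration, here is a minimal sketch of what option 2 could look like on a toy tensor subclass. `MyQuantTensor`, its fields, and its internal `dequantize()` helper are hypothetical stand-ins for a TorchAOBaseTensor child, not the actual torchao implementation; the sketch only shows the intended semantics of `to_hp` falling back to `self.dtype`:

```python
from typing import Optional
import torch

class MyQuantTensor:  # hypothetical stand-in for a TorchAOBaseTensor child
    def __init__(self, qdata: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        self.qdata = qdata       # e.g. int8 values
        self.scale = scale       # per-tensor scale
        self.dtype = orig_dtype  # dtype of the original high precision tensor

    def dequantize(self) -> torch.Tensor:
        # unlike torch.Tensor.dequantize(), which always returns float32,
        # this returns the original high precision dtype
        return (self.qdata.to(torch.float32) * self.scale).to(self.dtype)

    def to_hp(self, target_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
        # option 2: fall back to the dtype the tensor was created from
        if target_dtype is None:
            target_dtype = self.dtype
        return self.dequantize().to(target_dtype)


# usage: a third-party integration falling back to bfloat16 weight-only quantization
q = MyQuantTensor(
    torch.randint(-128, 127, (4, 4), dtype=torch.int8),
    torch.tensor(0.05),
    torch.bfloat16,
)
w_bf16 = q.to_hp()               # defaults to the original dtype (bfloat16)
w_fp32 = q.to_hp(torch.float32)  # explicit target dtype
```

From the integration side (e.g. the vLLM fallback path mentioned above), callers could then rely on `to_hp()` with no argument to recover the original high precision weights, or pass an explicit dtype when they need something else.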