From 73940fead3e41d55c5dc60138224575369c660df Mon Sep 17 00:00:00 2001
From: Andrew Or
Date: Wed, 22 Oct 2025 11:44:08 -0700
Subject: [PATCH] Remove internal executorch dependency on
 torchao.quantization.subclass (#15223)

Summary:
**Summary:** This is a really old quantization API that we recently removed
from torchao (D84842047). No one should be calling it anymore. For BC, let's
just copy the base class into executorch for now. We should delete this in
the future.

**Test Plan:** CI

bypass-github-export-checks

Reviewed By: vkuzo

Differential Revision: D84921134
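
For reference, a minimal sketch of what the BC copy looks like to callers.
The absolute module path and the exact warning behavior are assumptions
based on this diff, not verified against the repo:

```python
import torch

# Path assumed from this diff's file location inside the executorch repo.
from executorch.examples.models.llama.experimental.subclass import (
    QuantizedLinearWeightBase,
)

# Constructing the legacy base class still works, but the @deprecated
# decorator added in this patch should emit a DeprecationWarning.
w = QuantizedLinearWeightBase(
    torch.zeros(8, 4, dtype=torch.uint8),  # packed integer weight data
    False,                                 # not transposed
    torch.Size([8, 8]),                    # external (dequantized) shape
    dtype=torch.float16,
)
print(w.shape, w.dtype)  # torch.Size([8, 8]) torch.float16
```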
---
 .../models/llama/experimental/__init__.py    |   0
 .../llama/experimental/load_gguf_q4_0.py     |   3 +-
 .../models/llama/experimental/subclass.py    | 174 +++++++++++++++++-
 3 files changed, 175 insertions(+), 2 deletions(-)
 create mode 100644 examples/models/llama/experimental/__init__.py

diff --git a/examples/models/llama/experimental/__init__.py b/examples/models/llama/experimental/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/examples/models/llama/experimental/load_gguf_q4_0.py b/examples/models/llama/experimental/load_gguf_q4_0.py
index 39b81ea64a2..8bffde3e5fb 100644
--- a/examples/models/llama/experimental/load_gguf_q4_0.py
+++ b/examples/models/llama/experimental/load_gguf_q4_0.py
@@ -26,7 +26,8 @@
 from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
 from gguf import ReaderTensor
 from gguf.constants import GGMLQuantizationType
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from .subclass import QuantizedLinearWeightBase
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
diff --git a/examples/models/llama/experimental/subclass.py b/examples/models/llama/experimental/subclass.py
index 0a38af1efcf..653d7c2bf4b 100644
--- a/examples/models/llama/experimental/subclass.py
+++ b/examples/models/llama/experimental/subclass.py
@@ -20,7 +20,12 @@
 #
 # This layout is handled internally in the tensor subclass.
 import torch
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from executorch.exir._warnings import deprecated
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+
+aten = torch.ops.aten
 
 
 def down_size(size):
@@ -129,6 +134,173 @@ def to_float(
     return a * scale.unsqueeze(1)
 
 
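+# The class below is a "wrapper" tensor subclass: torch.Tensor._make_wrapper_subclass
+# advertises the external shape/dtype, the packed integer weights live in
+# self.int_data, and __torch_function__/__torch_dispatch__ route linear/mm
+# calls to the subclass-specific _quantized_op.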
+@deprecated("QuantizedLinearWeightBase is deleted from torchao. DO NOT USE!")
+class QuantizedLinearWeightBase(torch.Tensor):
+    """
+    *** LEGACY TORCHAO TENSOR SUBCLASS ***
+
+    Note: this subclass no longer exists in torchao. No one should be importing or extending this
+    subclass anymore. It is added back here just for internal executorch BC. DO NOT USE!
+
+    Base quantized tensor subclass for quantized linear weights. When the from_float method is used
+    to create an instance of any QuantizedLinearWeightBase, we assume the input
+    weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels.
+
+    The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
+    regardless of the internal representation's type or orientation.
+    """
+
+    @staticmethod
+    def __new__(cls, int_data, transposed, shape, *args, **kwargs):
+        kwargs["device"] = int_data.device
+        kwargs["layout"] = (
+            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
+        )
+        assert "dtype" in kwargs
+        assert not kwargs.get("requires_grad", False)
+        kwargs["requires_grad"] = False
+        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+
+    def __init__(self, int_data, transposed, *args, **kwargs):
+        self.int_data = int_data
+        self.transposed = transposed
+
+    @staticmethod
+    def _quantized_op(act_mat, w_qtensor, bias):
+        pass
+
+    def __repr__(self):
+        return (
+            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
+            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
+        )
+
+    def dequantize(self):
+        pass
+
+    def int_repr(self):
+        pass
+
+    def q_params(self):
+        pass
+
+    def half(self):
+        return self.to(torch.float16)
+
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def _apply_fn_to_data(self, fn):
+        pass
+
+    # Note: the stub takes a shape argument to match the call site in
+    # __torch_dispatch__ below; subclasses override it with the same signature.
+    def _change_shape(self, shape):
+        pass
+
+    def __tensor_flatten__(self):
+        pass
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        pass
+
+    @classmethod
+    def from_float(cls, input_float):
+        pass
+
+    # __torch_function__ = torch._C._disabled_torch_function_impl
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is torch.nn.functional.linear:
+            mat1, w_qtensor, bias = (
+                args[0],
+                args[1],
+                args[2] if len(args) > 2 else None,
+            )
+            assert not w_qtensor.transposed
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        try:
+            with torch._C.DisableTorchFunctionSubclass():
+                return func(*args, **kwargs)
+        except Exception:
+            print(f"ERR: subclass doesn't implement {func}")
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        # two scenarios where we currently fall back to vanilla mm:
+        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
+        #     for consistency and to allow people to test
+        # 2 - we're given non-floats - quantizing long to int8 is crazy
+        if (
+            func in [aten.mm.default, aten.addmm.default]
+            and args[0].is_floating_point()
+            and args[0].is_cuda
+        ):
+            if func == aten.addmm.default:
+                assert args[1].shape[-1] == args[2].shape[0], (
+                    f"need mat1 shape: {args[1].shape} final "
+                    f"dim to match mat2 shape: {args[2].shape} first dim"
+                )
+                mat1, w_qtensor, bias = (
+                    args[1],
+                    args[2],
+                    args[0],
+                )
+            else:
+                assert args[0].shape[-1] == args[1].shape[0], (
+                    f"need mat1 shape: {args[0].shape} final dim "
+                    f"to match mat2 shape: {args[1].shape} first dim"
+                )
+                mat1, w_qtensor, bias = (
+                    args[0],
+                    args[1],
+                    None if len(args) == 2 else args[2],
+                )
+            # call the quantized op for the specific type
+            # of quantized tensor subclass
+            return cls._quantized_op(mat1, w_qtensor, bias)
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.clone.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
+            )
+
+        if func is aten.t.default:
+            args[0].transposed = not args[0].transposed
+            new = args[0]._change_shape(args[0].shape[::-1])
+            return return_and_correct_aliasing(func, args, kwargs, new)
+
+        if func is aten._to_copy.default:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+            )
+
+
 class GGMLInt4LinearWeight(QuantizedLinearWeightBase):
     """
     A Tensor subclass that when applied to a weight used in a linear op/module,