diff --git a/examples/models/llama/experimental/__init__.py b/examples/models/llama/experimental/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/examples/models/llama/experimental/load_gguf_q4_0.py b/examples/models/llama/experimental/load_gguf_q4_0.py
index 39b81ea64a2..8bffde3e5fb 100644
--- a/examples/models/llama/experimental/load_gguf_q4_0.py
+++ b/examples/models/llama/experimental/load_gguf_q4_0.py
@@ -26,7 +26,8 @@
 from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
 from gguf import ReaderTensor
 from gguf.constants import GGMLQuantizationType
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from .subclass import QuantizedLinearWeightBase
 
 FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
 logging.basicConfig(level=logging.INFO, format=FORMAT)
diff --git a/examples/models/llama/experimental/subclass.py b/examples/models/llama/experimental/subclass.py
index 0a38af1efcf..faa70120791 100644
--- a/examples/models/llama/experimental/subclass.py
+++ b/examples/models/llama/experimental/subclass.py
@@ -20,7 +20,11 @@
 #
 # This layout is handled internally in the tensor subclass.
 import torch
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+from torch.utils._python_dispatch import return_and_correct_aliasing
+from typing_extensions import deprecated
+
+
+aten = torch.ops.aten
 
 
 def down_size(size):
@@ -129,6 +133,173 @@ def to_float(
     return a * scale.unsqueeze(1)
 
 
+@deprecated("QuantizedLinearWeightBase is deleted from torchao. DO NOT USE!")
+class QuantizedLinearWeightBase(torch.Tensor):
+    """
+    *** LEGACY TORCHAO TENSOR SUBCLASS ***
+
+    Note: this subclass no longer exists in torchao. No one should be importing
+    or extending this subclass anymore. It is added back here just for internal
+    executorch BC. DO NOT USE!
+
+    Base quantized tensor subclass for quantized linear weights. When the
+    from_float method is used to create an instance of any
+    QuantizedLinearWeightBase, we assume the input weight is oriented the way it
+    is in a normal linear op, i.e. out-channels x in-channels.
+
+    The shape and dtype of the tensor subclass represent how the tensor subclass
+    looks externally, regardless of the internal representation's type or
+    orientation.
+    """
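+
+    # Illustrative note (assumption, not from the deleted torchao source): the
+    # subclass advertises the logical float weight, so for a Linear weight of
+    # logical shape (out_features, in_features) = (32, 64), `shape` stays
+    # (32, 64) and `dtype` stays the float dtype, while the packed integer
+    # payload, with its own dtype and size, lives in `int_data`.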
+ """ + + @staticmethod + def __new__(cls, int_data, transposed, shape, *args, **kwargs): + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + assert "dtype" in kwargs + assert not kwargs.get("requires_grad", False) + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__(self, int_data, transposed, *args, **kwargs): + self.int_data = int_data + + self.transposed = transposed + + @staticmethod + def _quantized_op(act_mat, w_qtensor, bias): + pass + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) + + def dequantize(self): + pass + + def int_repr(self): + pass + + def q_params(self): + pass + + def half(self): + return self.to(torch.float16) + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def _apply_fn_to_data(self, fn): + pass + + def _change_shape(self): + pass + + def __tensor_flatten__(self): + pass + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + pass + + @classmethod + def from_float(cls, input_float): + pass + + # __torch_function__ = torch._C._disabled_torch_function_impl + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + mat1, w_qtensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert not w_qtensor.transposed + return cls._quantized_op(mat1, w_qtensor, bias) + + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except Exception: + print(f"ERR: subclass doesn't implement {func}") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # two scenarios where we currently fall back to vanilla mm: + # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation + # for consistency and to allow people to test + # 2 - we're given non-floats - quantizing long to int8 is crazy + if ( + func in [aten.mm.default, aten.addmm.default] + and args[0].is_floating_point() + and args[0].is_cuda + ): + if func == aten.addmm.default: + assert args[1].shape[-1] == args[2].shape[0], ( + f"need mat1 shape: {args[1].shape} final" + f"dim to match mat2 shape: {args[2].shape} first dim " + ) + mat1, w_qtensor, bias = ( + args[1], + args[2], + args[0], + ) + else: + assert args[0].shape[-1] == args[1].shape[0], ( + f"need mat1 shape: {args[0].shape} final dim" + f"to match mat2 shape: {args[1].shape} first dim" + ) + mat1, w_qtensor, bias = ( + args[0], + args[1], + None if len(args) == 2 else args[2], + ) + # call the quantized op for the specific type + # of quantized tensor subclass + return cls._quantized_op(mat1, w_qtensor, bias) + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is 
+
+
 class GGMLInt4LinearWeight(QuantizedLinearWeightBase):
     """
     A Tensor subclass that when applied to a weight used in a linear op/module,