3 changes: 2 additions & 1 deletion examples/models/llama/experimental/load_gguf_q4_0.py
@@ -26,7 +26,8 @@
from executorch.extension.gguf_util.load_gguf import GGUFWeights, load_file
from gguf import ReaderTensor
from gguf.constants import GGMLQuantizationType
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+
+from .subclass import QuantizedLinearWeightBase

FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
logging.basicConfig(level=logging.INFO, format=FORMAT)
174 changes: 173 additions & 1 deletion examples/models/llama/experimental/subclass.py
@@ -20,7 +20,12 @@
#
# This layout is handled internally in the tensor subclass.
import torch
-from torchao.quantization.subclass import QuantizedLinearWeightBase
+from torch.utils._python_dispatch import return_and_correct_aliasing
+
+from executorch.exir._warnings import deprecated
+
+
+aten = torch.ops.aten


def down_size(size):
@@ -129,6 +134,173 @@ def to_float(
    return a * scale.unsqueeze(1)


@deprecated("QuantizedLinearWeightBase is deleted from torchao. DO NOT USE!")
class QuantizedLinearWeightBase(torch.Tensor):

Review comment (Contributor):

Can we tag this as deprecated?

    from executorch.exir._warnings import deprecated
    ...
    @deprecated("QuantizedLinearWeightBase is deleted from torchao ... ")
    class QuantizedLinearWeightBase(torch.Tensor):

"""
*** LEGACY TORCHAO TENSOR SUBCLASS ***

Note: this subclass no longer exists in torchao. No one should be importing or extending this
subclass anymore. It is added back here just for internal executorch BC. DO NOT USE!

Base quantized tensor subclass for quantized linear weights. When the from_float method is used,
to create an instance of any QuantizedLinearWeightBase, we assume the input
weight is oriented the way it is in a normal linear op, i.e. out-channels x in-channels.

The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
regardless of the internal representation's type or orientation.
"""

    @staticmethod
    def __new__(cls, int_data, transposed, shape, *args, **kwargs):
        kwargs["device"] = int_data.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
        )
        assert "dtype" in kwargs
        assert not kwargs.get("requires_grad", False)
        kwargs["requires_grad"] = False
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(self, int_data, transposed, *args, **kwargs):
        self.int_data = int_data
        self.transposed = transposed

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        pass

    def __repr__(self):
        return (
            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
            f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
        )

    def dequantize(self):
        pass

    def int_repr(self):
        pass

    def q_params(self):
        pass

    def half(self):
        return self.to(torch.float16)

    def _get_to_kwargs(self, *args, **kwargs):
        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
        device = self.device if device is None else device
        dtype = self.dtype if dtype is None else dtype
        memory_format = (
            memory_format if memory_format is not None else torch.preserve_format
        )
        kwargs = {
            "device": device,
            "dtype": dtype,
            "memory_format": memory_format,
        }
        return kwargs

    def _apply_fn_to_data(self, fn):
        pass

    def _change_shape(self, shape):
        pass

    def __tensor_flatten__(self):
        pass

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        pass

    @classmethod
    def from_float(cls, input_float):
        pass

    # __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        if func is torch.nn.functional.linear:
            mat1, w_qtensor, bias = (
                args[0],
                args[1],
                args[2] if len(args) > 2 else None,
            )
            assert not w_qtensor.transposed
            return cls._quantized_op(mat1, w_qtensor, bias)

        try:
            with torch._C.DisableTorchFunctionSubclass():
                return func(*args, **kwargs)
        except Exception:
            print(f"ERR: subclass doesn't implement {func}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        # two scenarios where we currently fall back to vanilla mm:
        # 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU
        #     implementation for consistency and to allow people to test
        # 2 - we're given non-floats - quantizing long to int8 is crazy
        if (
            func in [aten.mm.default, aten.addmm.default]
            and args[0].is_floating_point()
            and args[0].is_cuda
        ):
            if func == aten.addmm.default:
                assert args[1].shape[-1] == args[2].shape[0], (
                    f"need mat1 shape: {args[1].shape} final "
                    f"dim to match mat2 shape: {args[2].shape} first dim"
                )
                mat1, w_qtensor, bias = (
                    args[1],
                    args[2],
                    args[0],
                )
            else:
                assert args[0].shape[-1] == args[1].shape[0], (
                    f"need mat1 shape: {args[0].shape} final dim "
                    f"to match mat2 shape: {args[1].shape} first dim"
                )
                mat1, w_qtensor, bias = (
                    args[0],
                    args[1],
                    None if len(args) == 2 else args[2],
                )
            # call the quantized op for the specific type
            # of quantized tensor subclass
            return cls._quantized_op(mat1, w_qtensor, bias)

        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

        if func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        if func is aten.t.default:
            args[0].transposed = not args[0].transposed
            new = args[0]._change_shape(args[0].shape[::-1])
            return return_and_correct_aliasing(func, args, kwargs, new)

        if func is aten._to_copy.default:
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
            )


class GGMLInt4LinearWeight(QuantizedLinearWeightBase):
    """
    A Tensor subclass that when applied to a weight used in a linear op/module,
[diff truncated]
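
For orientation, here is a minimal sketch of how a concrete subclass is meant to fill in the stubs this legacy base leaves empty. It is not code from the PR: the class name Int8SymmetricWeight, its per-tensor symmetric int8 scheme, the dequantize-then-float-linear fallback in _quantized_op, and the import path for the module above are all illustrative assumptions.

import torch

# Assumed import path for the file shown in this diff.
from executorch.examples.models.llama.experimental.subclass import (
    QuantizedLinearWeightBase,
)


class Int8SymmetricWeight(QuantizedLinearWeightBase):
    """Toy per-tensor symmetric int8 weight (illustrative only)."""

    def __new__(cls, int_data, scale, transposed, shape, **kwargs):
        # The base asserts "dtype" is in kwargs; report the float dtype externally.
        kwargs["dtype"] = kwargs.get("dtype", scale.dtype)
        return super().__new__(cls, int_data, transposed, shape, **kwargs)

    def __init__(self, int_data, scale, transposed, shape, **kwargs):
        super().__init__(int_data, transposed)
        self.scale = scale

    @staticmethod
    def _quantized_op(act_mat, w_qtensor, bias):
        # No real quantized kernel here: dequantize and run a float linear.
        return torch.nn.functional.linear(act_mat, w_qtensor.dequantize(), bias)

    def dequantize(self):
        return self.int_data.to(self.scale.dtype) * self.scale

    @classmethod
    def from_float(cls, input_float):
        # Weight arrives as (out_features, in_features), per the base docstring.
        scale = input_float.abs().amax().clamp(min=1e-8) / 127.0
        int_data = torch.clamp(torch.round(input_float / scale), -128, 127).to(
            torch.int8
        )
        return cls(
            int_data, scale, False, input_float.shape, dtype=input_float.dtype
        )


w = Int8SymmetricWeight.from_float(torch.randn(8, 4))
x = torch.randn(2, 4)
y = torch.nn.functional.linear(x, w)  # intercepted by __torch_function__
print(y.shape)  # torch.Size([2, 8])

Because the base's __torch_function__ routes torch.nn.functional.linear to cls._quantized_op, a subclass only needs to supply _quantized_op, dequantize, and from_float for this basic path to work.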