@@ -80,6 +80,7 @@ Tensor pack_weights_cpu(
weight_scales.dtype() == torch::kFloat32,
"weight_scales must be float32");
CHECK_MSG(weight_scales.dim() == 1, "weight_scales must be 1D");
CHECK_MSG(group_size >= 1, "group_size must be >= 1");
CHECK_MSG(
weight_scales.size(0) == ((n * k) / group_size),
"expected 1 scale per group");
@@ -134,9 +135,9 @@ Tensor pack_weights_without_zeros_cpu(
const Tensor& weight_qvals,
const Tensor& weight_scales,
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
// group_size is a tensor with size (0, group_size)
const Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);
int64_t group_size = group_size_tensor.size(1);
return pack_weights_cpu<weight_nbit, /*has_weight_zeros*/ false>(
weight_qvals, weight_scales, std::nullopt, group_size);
}
@@ -151,7 +152,7 @@ Tensor pack_weights_with_zeros_cpu(
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
const Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);
int64_t group_size = group_size_tensor.size(1);
return pack_weights_cpu<weight_nbit, /*has_weight_zeros*/ true>(
weight_qvals, weight_scales, weight_zeros, group_size);
}
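
The size(0) to size(1) changes above reflect the new convention for smuggling scalar arguments past AOTI: each int is carried as an empty int8 tensor of shape (0, value) and recovered from the second dimension, matching the (0, n) / (0, k) / (0, group_size) tensors built in quant_api.py below. A minimal round-trip sketch, with wrap_scalar/unwrap_scalar as illustrative names only:

import torch

def wrap_scalar(value: int) -> torch.Tensor:
    # Encode an int as an empty (0, value) int8 tensor; the leading 0 means
    # no data is allocated, the size itself carries the value.
    return torch.empty(0, value, dtype=torch.int8)

def unwrap_scalar(t: torch.Tensor) -> int:
    # Mirrors the C++ side's group_size_tensor.size(1).
    return t.shape[1]

group_size_tensor = wrap_scalar(32)
assert group_size_tensor.numel() == 0          # no data is stored
assert unwrap_scalar(group_size_tensor) == 32  # the size carries the value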
@@ -164,6 +165,7 @@ Tensor pack_weights_meta(
const Tensor& weight_scales,
const std::optional<Tensor>& weight_zeros,
int64_t group_size) {
CHECK_MSG(group_size >= 1, "group_size must be >= 1");
int n = weight_qvals.size(0);
int k = weight_qvals.size(1);

@@ -190,7 +192,7 @@ Tensor pack_weights_without_zeros_meta(
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
const Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);
int64_t group_size = group_size_tensor.size(1);
return pack_weights_meta<weight_nbit, /*has_weight_zeros*/ false>(
weight_qvals, weight_scales, std::nullopt, group_size);
}
@@ -205,7 +207,7 @@ Tensor pack_weights_with_zeros_meta(
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
const Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);
int64_t group_size = group_size_tensor.size(1);
return pack_weights_meta<weight_nbit, /*has_weight_zeros*/ true>(
weight_qvals, weight_scales, weight_zeros, group_size);
}
@@ -216,16 +218,19 @@ template <int weight_nbit, bool has_weight_zeros>
Tensor linear_out_cpu(
const Tensor& packed_weights,
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
// int64_t when supported by AOTI Currently they are meta tensors with size
// equal to the int they wrap
// int64_t when supported by AOTI Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& n_tensor,
const Tensor& k_tensor,
const Tensor& group_size_tensor,
const Tensor& activations,
Tensor& out) {
int n = n_tensor.size(0);
int k = k_tensor.size(0);
int group_size = group_size_tensor.size(0);
int n = n_tensor.size(1);
int k = k_tensor.size(1);
int group_size = group_size_tensor.size(1);
CHECK_MSG(n >= 1, "n must be >= 1");
CHECK_MSG(k >= 1, "k must be >= 1");
CHECK_MSG(group_size >= 1, "group_size must be >= 1");

#ifdef USE_ATEN
CHECK_MSG(
@@ -303,8 +308,8 @@ template <int weight_nbit, bool has_weight_zeros>
Tensor linear_cpu(
const Tensor& packed_weights,
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
// int64_t when supported by AOTI Currently they are meta tensors with size
// equal to the int they wrap
// int64_t when supported by AOTI Currently they are tensors with size
// equal to (0, the int they wrap)
const Tensor& n_tensor,
const Tensor& k_tensor,
const Tensor& group_size_tensor,
@@ -327,14 +332,17 @@ Tensor linear_meta(
const Tensor& packed_weights,
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
// int64_t when supported by AOTI
// Currently they are meta tensors with size equal to the int they wrap
// Currently they are tensors with size equal to (0, the int they wrap)
const Tensor& n_tensor,
const Tensor& k_tensor,
const Tensor& group_size_tensor,
const Tensor& activations) {
int n = n_tensor.size(0);
int k = k_tensor.size(0);
int n = n_tensor.size(1);
int k = k_tensor.size(1);
CHECK_MSG(n >= 1, "n must be >= 1");
CHECK_MSG(k >= 1, "k must be >= 1");

CHECK_MSG(activations.dim() == 2, "activations must be 2D");
int m = activations.size(0);
int k_ = activations.size(1);
CHECK_MSG(k == k_, "activation shape is incompatible with packed weights.");
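
A small sketch of the shape agreement linear_meta now enforces, using made-up n/k/m values: n and k come out of the wrapped (0, n) and (0, k) tensors, must be positive, and the activations' last dimension must equal k:

import torch

n_tensor = torch.empty(0, 256, dtype=torch.int8)   # wraps n = 256
k_tensor = torch.empty(0, 512, dtype=torch.int8)   # wraps k = 512
activations = torch.randn(8, 512)                  # (m, k_)

n, k = n_tensor.size(1), k_tensor.size(1)
assert n >= 1 and k >= 1
assert activations.dim() == 2
m, k_ = activations.shape
assert k == k_, "activation shape is incompatible with packed weights."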
12 changes: 7 additions & 5 deletions torchao/experimental/quant_api.py
@@ -80,9 +80,11 @@ def quantize_and_pack_weights(self, weights, nbit, group_size, has_weight_zeros)

# TODO(T200095131): convert self.n, self.k, self.group_size to
# int when supported by AOTI
self._n = torch.empty(n, dtype=torch.int8)
self._k = torch.empty(k, dtype=torch.int8)
self._group_size = torch.empty(self.group_size, dtype=torch.int8)
# AOTI does not allow a tensor of size (n, 0), so we do (0, n)
self._n = torch.empty(0, n, dtype=torch.int8)
self._k = torch.empty(0, k, dtype=torch.int8)
self._group_size = torch.empty(0, group_size, dtype=torch.int8)


weight_qvals, weight_scales, weight_zeros = _quantize(
weights, self.group_size, self.nbit, self.has_weight_zeros
@@ -109,7 +111,7 @@ def forward(self, x):
assert x.dim() >= 3
lead_shape = x.shape[0:-2]
m, k = x.shape[-2], x.shape[-1]
n = self._n.shape[0]
n = self._n.shape[1]
x = x.reshape(-1, m, k)

res = [
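
The forward path above (truncated in this diff) now reads n from the second dimension of self._n and flattens any extra leading batch dimensions before running the 2-D kernel slice by slice. A standalone sketch of just that shape bookkeeping, with illustrative sizes:

import torch

x = torch.randn(2, 1, 3, 512)        # activations with extra leading batch dims
assert x.dim() >= 3
lead_shape = x.shape[0:-2]           # (2, 1), kept so results can be reshaped back
m, k = x.shape[-2], x.shape[-1]      # 3, 512
n = torch.empty(0, 256, dtype=torch.int8).shape[1]  # stand-in for self._n.shape[1]
x = x.reshape(-1, m, k)              # (2, 3, 512): one 2-D problem per slice
assert x.shape == (2, m, k) and n == 256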
@@ -254,7 +256,7 @@ def _replace_linear_with_quantized_linear(module: nn.Module, kwargs={}):
if not isinstance(qlinear, _Int8DynActIntxWeightQuantizedLinearNative):
raise e
logger.warning(
"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
f"_Int8DynActIntxWeightQuantizedLinearNative raised an exception during quantize_and_pack_weights: {e}\n"
+ "Falling back to **slow** implementation _Int8DynActIntxWeightQuantizedLinearFallback."
)
qlinear = _Int8DynActIntxWeightQuantizedLinearFallback()
@@ -77,16 +77,17 @@ def test_accuracy(self):

def test_export_compile_aoti(self):
group_size = 32
m = 1
n = 256
k = 256
m = 3
k0 = 512
k1 = 256
k2 = 128
k3 = 1024
nbit = 4
has_weight_zeros = False
n_layers = 3
layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)]
layers = [torch.nn.Linear(k0, k1, bias=False), torch.nn.Linear(k1, k2, bias=False), torch.nn.Linear(k2, k3, bias=False)]
model = torch.nn.Sequential(*layers)

activations = torch.randn(m, k, dtype=torch.float32)
activations = torch.randn(2, 1, m, k0, dtype=torch.float32)

print("Quantizing model")
quantizer = Int8DynActIntxWeightQuantizer(
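
For reference, the reshaped test drives a three-layer model with distinct in/out features through a 4-D activation tensor. An unquantized sketch of the expected shape plumbing (quantization and AOTI compilation omitted):

import torch

m, k0, k1, k2, k3 = 3, 512, 256, 128, 1024
model = torch.nn.Sequential(
    torch.nn.Linear(k0, k1, bias=False),
    torch.nn.Linear(k1, k2, bias=False),
    torch.nn.Linear(k2, k3, bias=False),
)
activations = torch.randn(2, 1, m, k0, dtype=torch.float32)
out = model(activations)
assert out.shape == (2, 1, m, k3)   # leading batch dims flow through unchanged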