Support k_quant quantization in Olive #1818

Closed
wants to merge 1 commit into from
2 changes: 2 additions & 0 deletions olive/constants.py

@@ -67,6 +67,8 @@ class QuantAlgorithm(CaseInsensitiveEnum):
     RTN = "rtn"
     SPINQUANT = "spinquant"
     QUAROT = "quarot"
+    K_QUANT_MIXED = "k_quant_mixed"  # k_quant algorithm with mixed precision
+    K_QUANT_LAST = "k_quant_last"  # k_quant algorithm; only the last MatMul (/lm_head/MatMul) is quantized to int8, all other MatMuls to int4


 class QuantEncoding(StrEnumBase):
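For reference, a minimal sketch of how the new members resolve from a config string. QuantAlgorithm derives from CaseInsensitiveEnum, so mixed-case input is assumed to map to the same member (the case-insensitive lookup behavior is inferred from the base class name):

from olive.constants import QuantAlgorithm

# Assumed CaseInsensitiveEnum behavior: lookup ignores case.
assert QuantAlgorithm("k_quant_mixed") == QuantAlgorithm.K_QUANT_MIXED
assert QuantAlgorithm("K_Quant_Last") == QuantAlgorithm.K_QUANT_LAST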
51 changes: 44 additions & 7 deletions olive/passes/onnx/quantization.py
@@ -10,6 +10,7 @@
 from typing import Callable, Optional, Union

 import onnx
+from onnxruntime import __version__ as OrtVersion
 from packaging import version

 from olive.common.config_utils import validate_config
@@ -367,8 +368,6 @@ def _run_for_config(
             logger.info("Model has adapters which should not be quantized. Returning the model without quantization.")
             return model

-        from onnxruntime import __version__ as OrtVersion
-
         if version.parse(OrtVersion) < version.parse("1.18.0"):
             raise ValueError("Onnx Quantization is only supported for onnxruntime>=1.18.0")

@@ -824,8 +823,6 @@ def _validators(cls) -> dict[str, Callable]:
     def _run_for_config(
         self, model: ONNXModelHandler, config: type[BasePassConfig], output_model_path: str
     ) -> ONNXModelHandler:
-        from onnxruntime import __version__ as OrtVersion
-
         if version.parse(OrtVersion) < version.parse("1.18.0"):
             raise ValueError("MatMul4BitsQuantizer is only supported for onnxruntime>=1.18.0")

@@ -852,6 +849,12 @@ def _run_for_config(
             QuantAlgorithm.GPTQ: GPTQWeightOnlyQuantConfig,
         }

+        if version.parse(OrtVersion) > version.parse("1.22.0"):
+            from onnxruntime.quantization.matmul_nbits_quantizer import KQuantWeightOnlyQuantConfig
+
+            algo_to_config[QuantAlgorithm.K_QUANT_MIXED] = KQuantWeightOnlyQuantConfig
+            algo_to_config[QuantAlgorithm.K_QUANT_LAST] = KQuantWeightOnlyQuantConfig
+
         if model_has_adapters(model.model_path) and config.algorithm:
             logger.info(
                 "Model has adapters which should only be quantized with algorithm=None. Got %s. Returning"
@@ -874,6 +877,8 @@ def _run_for_config(
                     raise ValueError(f"MatMulNBitsQuantizer {key} is only supported for onnxruntime>=1.20.0")
                 kwargs[key] = value

+        onnx_model = model.load_model()
+
         if woq_config_class := algo_to_config.get(config.algorithm, None):
             algo_config = config.weight_only_quant_configs or {}
             for key in inspect.signature(woq_config_class.__init__).parameters:
@@ -893,11 +898,40 @@ def _run_for_config(
                     algo_config[key] = kwargs[key]
             if config.algorithm == QuantAlgorithm.GPTQ:
                 algo_config["calibration_data_reader"] = get_calibration_dataloader(config)
+            elif config.algorithm == QuantAlgorithm.K_QUANT_MIXED:
+                node_names = [node.name for node in onnx_model.graph.node]
+
+                import re
+
+                pattern = r"/model/layers\.(\d{1,2})/mlp/down_proj/MatMul"
+                layers = set()
+                for s in node_names:
+                    match = re.search(pattern, s)
+                    if match:
+                        layer_number = int(match.group(1))
+                        layers.add(layer_number)
+                n_layers = len(layers)
+                layers_to_exclude = [
+                    i
+                    for i in range(n_layers)
+                    if i < n_layers / 8 or i >= 7 * n_layers / 8 or (i - round(n_layers / 8)) % 3 == 2
+                ]
+                customized_weight_config = {}
+                for i in layers_to_exclude:
+                    customized_weight_config["/model/layers." + str(i) + "/attn/qkv_proj/MatMul"] = {"bits": 8}
+                    customized_weight_config["/model/layers." + str(i) + "/mlp/down_proj/MatMul"] = {"bits": 8}
+                    # Gemma model
+                    customized_weight_config["/model/layers." + str(i) + "/attn/v_proj/MatMul"] = {"bits": 8}
+                customized_weight_config["/lm_head/MatMul"] = {"bits": 8}
+                algo_config["customized_weight_config"] = customized_weight_config
+            elif config.algorithm == QuantAlgorithm.K_QUANT_LAST:
+                customized_weight_config = {"/lm_head/MatMul": {"bits": 8}}
+                algo_config["customized_weight_config"] = customized_weight_config
             kwargs["algo_config"] = woq_config_class(**algo_config)
         else:
             kwargs["algo_config"] = None

-        quant = MatMulNBitsQuantizer(model.load_model(), **kwargs)
+        quant = MatMulNBitsQuantizer(onnx_model, **kwargs)
         quant.process()
         # topologically sort the graph at the end since previous optimizations may have broken it
         quant.model.topological_sort()
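The K_QUANT_MIXED branch above keeps int8 precision on roughly the first eighth of the layers, the last eighth, and every third layer in between (plus /lm_head/MatMul). Evaluating the same selection rule for a hypothetical 32-layer model shows which layers are held at 8 bits:

# Same selection rule as in the diff, for a hypothetical 32-layer model.
n_layers = 32
layers_to_exclude = [
    i
    for i in range(n_layers)
    if i < n_layers / 8 or i >= 7 * n_layers / 8 or (i - round(n_layers / 8)) % 3 == 2
]
print(layers_to_exclude)
# -> [0, 1, 2, 3, 6, 9, 12, 15, 18, 21, 24, 27, 28, 29, 30, 31]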
@@ -918,7 +952,7 @@ def _validate_weight_only_quant_config(v, values, field):
     config_keys = list(v.keys())
     if not values["algorithm"]:
         default_config_keys = ["block_size", "is_symmetric", "accuracy_level"]
-    elif values["algorithm"] == QuantAlgorithm.RTN:
+    elif values["algorithm"] in [QuantAlgorithm.RTN, QuantAlgorithm.K_QUANT_MIXED, QuantAlgorithm.K_QUANT_LAST]:
         default_config_keys = ["ratios"]
     elif values["algorithm"] == QuantAlgorithm.GPTQ:
         default_config_keys = ["percdamp", "block_size", "actorder", "mse", "perchannel"]
@@ -939,7 +973,10 @@


 def _validate_algorithm(v, values, field):
-    if v not in (None, QuantAlgorithm.RTN, QuantAlgorithm.GPTQ):
+    valid_algorithm = [None, QuantAlgorithm.RTN, QuantAlgorithm.GPTQ]
+    if version.parse(OrtVersion) > version.parse("1.22.0"):
+        valid_algorithm.extend([QuantAlgorithm.K_QUANT_MIXED, QuantAlgorithm.K_QUANT_LAST])
+    if v not in valid_algorithm:
         raise ValueError(f"Unsupported algorithm: {v}")
     return v
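A small sketch of the intended gating; direct invocation of the module-level validator is assumed for illustration. With onnxruntime 1.22.0 or older installed, requesting either k_quant variant fails validation:

# Hypothetical direct call; values/field are unused on this path.
try:
    _validate_algorithm(QuantAlgorithm.K_QUANT_MIXED, values={}, field=None)
except ValueError as err:
    print(err)  # raised when onnxruntime <= 1.22.0 is installed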

19 changes: 19 additions & 0 deletions setup.py
@@ -4,6 +4,8 @@
 # --------------------------------------------------------------------------
 import json
 import os
+import re
+import subprocess

 from setuptools import find_packages, setup
@@ -28,6 +30,18 @@
         return json.load(fp)["extra_dependencies"]


+def get_cuda_version():
+    """Detect CUDA version from nvcc."""
+    try:
+        output = subprocess.check_output(["nvcc", "--version"]).decode()
+        match = re.search(r"release (\d+)\.", output)
+        if match:
+            return int(match.group(1))  # 11, 12, etc.
+    except Exception:
+        pass
+    return None
+
+
 # use techniques described at https://packaging.python.org/en/latest/guides/single-sourcing-package-version/
 # Don't use technique 6 since it needs extra dependencies.
 VERSION = get_version("olive/__init__.py")

Check notice (Code scanning / CodeQL): Empty except. The 'except' clause does nothing but pass and there is no explanatory comment.
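To illustrate the parsing, here is the same regex applied to a representative nvcc banner; the sample text is an assumption based on typical `nvcc --version` output:

import re

sample = (
    "nvcc: NVIDIA (R) Cuda compiler driver\n"
    "Cuda compilation tools, release 12.2, V12.2.140\n"
)
match = re.search(r"release (\d+)\.", sample)
print(int(match.group(1)))  # -> 12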
@@ -36,6 +50,11 @@
 with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")) as req_file:
     requirements = req_file.read().splitlines()

+cuda_version = get_cuda_version()
+if cuda_version == 11:
+    requirements.append("cupy-cuda11x")
+elif cuda_version == 12:
+    requirements.append("cupy-cuda12x")

 CLASSIFIERS = [
     "Development Status :: 3 - Alpha",