
Commit 1980a69

Add torchao (#1182)
* init
* update install utils
* update
* update libs
* update torchao pin
* fix ci test
* add python et install to ci
* fix ci errors
* fixes (×7)
1 parent: e4b36f9 · commit: 1980a69

File tree

11 files changed: +451 −123 lines

.github/workflows/pull.yml

Lines changed: 232 additions & 111 deletions
Large diffs are not rendered by default.

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ __pycache__/
 # Build directories
 build/android/*
 et-build/*
+torchao-build/*
 runner-et/cmake-out/*
 runner-aoti/cmake-out/*
 cmake-out/

docs/quantization.md

Lines changed: 69 additions & 0 deletions
@@ -118,6 +118,75 @@ python3 torchchat.py export llama3 --quantize '{"embedding": {"bitwidth": 4, "gr
python3 torchchat.py generate llama3 --pte-path llama3.pte --prompt "Hello my name is"
```

## Experimental TorchAO lowbit kernels
### Use

The quantization scheme a8wxdq dynamically quantizes activations to 8 bits, and quantizes the weights in a groupwise manner with a specified bitwidth and groupsize.
It takes arguments bitwidth (2, 3, 4, 5, 6, or 7), groupsize, and has_weight_zeros (true or false).
The argument has_weight_zeros indicates whether the weights are quantized with scales only (has_weight_zeros: false) or with both scales and zeros (has_weight_zeros: true).
Roughly speaking, {bitwidth: 4, groupsize: 256, has_weight_zeros: false} is similar to GGML's Q4_0 quantization scheme.

You should expect high performance on ARM CPUs if bitwidth is 2, 3, 4, or 5 and groupsize is divisible by 16. On other platforms, or with other argument choices, a slow fallback kernel will be used, and you will see warnings about this during quantization.
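To make these arguments concrete, here is a minimal sketch of the groupwise arithmetic they describe. This is illustrative only, not torchao's kernel, and it omits the "a8" half of the scheme (dynamic 8-bit quantization of activations at runtime):

```python
# Illustrative sketch only -- not torchao's implementation. Shows what the
# weight-side arguments (bitwidth, groupsize, has_weight_zeros) compute.
import torch

def quantize_weights_groupwise(w, bitwidth=4, groupsize=256, has_weight_zeros=False):
    qmin, qmax = -(2 ** (bitwidth - 1)), 2 ** (bitwidth - 1) - 1
    groups = w.reshape(-1, groupsize)  # one scale (and optional zero) per group
    if has_weight_zeros:
        # Asymmetric quantization: per-group scale and zero point
        lo = groups.min(dim=1, keepdim=True).values
        hi = groups.max(dim=1, keepdim=True).values
        scale = ((hi - lo) / (qmax - qmin)).clamp(min=1e-8)
        zero = torch.round(qmin - lo / scale)
        q = torch.clamp(torch.round(groups / scale) + zero, qmin, qmax)
        dq = (q - zero) * scale  # dequantized approximation of w
    else:
        # Symmetric quantization: per-group scale only (roughly Q4_0 at bitwidth 4)
        scale = (groups.abs().max(dim=1, keepdim=True).values / qmax).clamp(min=1e-8)
        q = torch.clamp(torch.round(groups / scale), qmin, qmax)
        dq = q * scale
    return q.reshape(w.shape), dq.reshape(w.shape)

q, dq = quantize_weights_groupwise(torch.randn(8, 4096))
```

Smaller groupsize means more scales to store but a closer fit to the weights; has_weight_zeros buys extra accuracy for asymmetric weight distributions at the cost of storing a zero point per group.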
### Setup

To use a8wxdq, you must set up the torchao experimental kernels. These will only work on devices with ARM CPUs, for example on Mac computers with Apple Silicon.

From the torchchat root directory, run

```
sh torchchat/utils/scripts/build_torchao_ops.sh
```

This should take about 10 seconds to complete. Once finished, you can use a8wxdq in torchchat.
Note: if you want to use the new kernels in the AOTI and C++ runners, you must pass the flag link_torchao_ops when running the scripts that build the runners.

```
sh torchchat/utils/scripts/build_native.sh aoti link_torchao_ops
```

```
sh torchchat/utils/scripts/build_native.sh et link_torchao_ops
```

Note: before running `sh torchchat/utils/scripts/build_native.sh et link_torchao_ops`, you must first install ExecuTorch with `sh torchchat/utils/scripts/install_et.sh` if you have not done so already.
### Examples

Below we show how to use the new kernels. Except with ExecuTorch, you can set the number of threads with OMP_NUM_THREADS (as with PyTorch in general). Doing so is optional; a default number of threads will be chosen automatically if you do not set one.
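If you want to confirm the thread count PyTorch actually picked up, a quick check (a sketch; on standard OpenMP builds of PyTorch, OMP_NUM_THREADS is read at interpreter startup):

```python
# Run as: OMP_NUM_THREADS=6 python3 thread_check.py
import torch
print(torch.get_num_threads())  # expect 6 when launched as above
```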
#### Eager mode

```
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --prompt "Once upon a time," --num-samples 5
```
#### torch.compile

```
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --compile --prompt "Once upon a time," --num-samples 5
```
#### AOTI

```
OMP_NUM_THREADS=6 python3 torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-dso llama3_1.so
OMP_NUM_THREADS=6 python3 torchchat.py generate llama3.1 --dso-path llama3_1.so --prompt "Once upon a time," --num-samples 5
```
If you built the AOTI runner with link_torchao_ops as discussed in the setup section, you can also use the C++ runner:

```
OMP_NUM_THREADS=6 ./cmake-out/aoti_run llama3_1.so -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
```
#### ExecuTorch

```
python3 torchchat.py export llama3.1 --device cpu --dtype float32 --quantize '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}' --output-pte llama3_1.pte
```

Note: the exported *.pte file can only be run with torchchat's ExecuTorch C++ runner, built using the setup instructions above; it will not work with the `python3 torchchat.py generate` command.

```
./cmake-out/et_run llama3_1.pte -z $HOME/.torchchat/model-cache/meta-llama/Meta-Llama-3.1-8B-Instruct/tokenizer.model -l 3 -i "Once upon a time,"
```
## Quantization Profiles

Four [sample profiles](https://github.com/pytorch/torchchat/tree/main/torchchat/quant_config/) are included with the torchchat distribution: `cuda.json`, `desktop.json`, `mobile.json`, and `pi5.json`.

install/.pins/torchao-pin.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+63cb7a9857654784f726fec75c0dc36167094d8a
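The pin records which torchao commit the build scripts check out. A hypothetical sketch of how such a pin is consumed (the actual logic is shell code in install_utils.sh, which this page does not show; the clone path torchao-build/src/ao matches the path referenced in quantize.py below):

```python
# Hypothetical sketch only; the real pin handling lives in install_utils.sh.
import subprocess
from pathlib import Path

pin = Path("install/.pins/torchao-pin.txt").read_text().strip()
subprocess.run(["git", "-C", "torchao-build/src/ao", "checkout", pin], check=True)
```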

runner/aoti.cmake

Lines changed: 4 additions & 0 deletions
@@ -28,3 +28,7 @@ if(Torch_FOUND)
   target_link_libraries(aoti_run "${TORCH_LIBRARIES}" m)
   set_property(TARGET aoti_run PROPERTY CXX_STANDARD 17)
 endif()
+
+if (LINK_TORCHAO_OPS)
+  target_link_libraries(aoti_run "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/liblinear_a8wxdq_ATEN${CMAKE_SHARED_LIBRARY_SUFFIX}")
+endif()

runner/et.cmake

Lines changed: 8 additions & 0 deletions
@@ -116,6 +116,14 @@ if(executorch_FOUND)
     target_link_libraries(et_run PRIVATE log)
   endif()

+  if(LINK_TORCHAO_OPS)
+    target_link_libraries(et_run PRIVATE "$<LINK_LIBRARY:WHOLE_ARCHIVE,${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/liblinear_a8wxdq_EXECUTORCH.a>")
+    target_link_libraries(et_run PRIVATE
+      "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_kernels_aarch64.a"
+      "${TORCHCHAT_ROOT}/torchao-build/cmake-out/lib/libtorchao_ops_linear_EXECUTORCH.a"
+    )
+  endif()
+
 else()
   MESSAGE(WARNING "ExecuTorch package not found")
 endif()

torchchat/utils/quantize.py

Lines changed: 43 additions & 4 deletions
@@ -96,10 +96,19 @@ def quantize_model(
     precision = get_precision()

     try:
-        # Easier to ask forgiveness than permission
-        quant_handler = ao_quantizer_class_dict[quantizer](
-            groupsize=q_kwargs["groupsize"], device=device, precision=precision
-        )
+        if quantizer == "linear:a8wxdq":
+            quant_handler = ao_quantizer_class_dict[quantizer](
+                device=device,
+                precision=precision,
+                bitwidth=q_kwargs.get("bitwidth", 4),
+                groupsize=q_kwargs.get("groupsize", 128),
+                has_weight_zeros=q_kwargs.get("has_weight_zeros", False),
+            )
+        else:
+            # Easier to ask forgiveness than permission
+            quant_handler = ao_quantizer_class_dict[quantizer](
+                groupsize=q_kwargs["groupsize"], device=device, precision=precision
+            )
     except TypeError as e:
         if "unexpected keyword argument 'device'" in str(e):
             quant_handler = ao_quantizer_class_dict[quantizer](
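For orientation, the `--quantize` flag shown in docs/quantization.md is a JSON object whose keys name the quantizer and whose values arrive in this dispatch as q_kwargs. A minimal sketch of that mapping (illustrative, not torchchat's exact parsing code):

```python
# Illustrative only -- not torchchat's exact parsing code.
import json

quantize_flag = '{"linear:a8wxdq": {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": false}}'
for quantizer, q_kwargs in json.loads(quantize_flag).items():
    # quantizer == "linear:a8wxdq"
    # q_kwargs == {"bitwidth": 4, "groupsize": 256, "has_weight_zeros": False}
    print(quantizer, q_kwargs)
```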
@@ -861,3 +870,33 @@ def quantized_model(self) -> nn.Module:
     "linear:int4": Int4WeightOnlyQuantizer,
     "linear:a8w4dq": Int8DynActInt4WeightQuantizer,
 }
+
+try:
+    import importlib.util
+    import sys
+    import os
+
+    torchao_build_path = f"{os.getcwd()}/torchao-build"
+
+    # Try loading quantizer
+    torchao_experimental_quant_api_spec = importlib.util.spec_from_file_location(
+        "torchao_experimental_quant_api",
+        f"{torchao_build_path}/src/ao/torchao/experimental/quant_api.py",
+    )
+    torchao_experimental_quant_api = importlib.util.module_from_spec(
+        torchao_experimental_quant_api_spec
+    )
+    sys.modules["torchao_experimental_quant_api"] = torchao_experimental_quant_api
+    torchao_experimental_quant_api_spec.loader.exec_module(torchao_experimental_quant_api)
+    from torchao_experimental_quant_api import Int8DynActIntxWeightQuantizer
+    ao_quantizer_class_dict["linear:a8wxdq"] = Int8DynActIntxWeightQuantizer
+
+    # Try loading custom op
+    try:
+        import glob
+        libs = glob.glob(f"{torchao_build_path}/cmake-out/lib/liblinear_a8wxdq_ATEN.*")
+        libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
+        torch.ops.load_library(libs[0])
+    except Exception as e:
+        print("Failed to load torchao ops library with error: ", e)
+        print("Slow fallback kernels will be used.")
+
+except Exception as e:
+    print(f"Failed to load torchao experimental a8wxdq quantizer with error: {e}")

torchchat/utils/scripts/build_native.sh

Lines changed: 22 additions & 2 deletions
@@ -26,6 +26,7 @@ if [ $# -eq 0 ]; then
   exit 1
 fi

+LINK_TORCHAO_OPS=OFF
 while (( "$#" )); do
   case "$1" in
     -h|--help)
@@ -42,6 +43,11 @@ while (( "$#" )); do
       TARGET="et"
       shift
       ;;
+    link_torchao_ops)
+      echo "Linking with torchao ops..."
+      LINK_TORCHAO_OPS=ON
+      shift
+      ;;
     *)
       echo "Invalid option: $1"
       show_help
@@ -66,14 +72,28 @@ if [[ "$TARGET" == "et" ]]; then
     echo "Make sure you run install_executorch_libs"
     exit 1
   fi
+
+  if [[ "$LINK_TORCHAO_OPS" == "ON" ]]; then
+    if [ ! -d "${TORCHCHAT_ROOT}/torchao-build" ]; then
+      echo "Directory ${TORCHCHAT_ROOT}/torchao-build does not exist."
+      echo "Make sure you run clone_torchao"
+      exit 1
+    fi
+
+    source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"
+    find_cmake_prefix_path
+    EXECUTORCH_INCLUDE_DIRS="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/include;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/src"
+    EXECUTORCH_LIBRARIES="${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libexecutorch_no_prim_ops.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libextension_threadpool.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libcpuinfo.a;${TORCHCHAT_ROOT}/${ET_BUILD_DIR}/install/lib/libpthreadpool.a"
+    install_torchao_executorch_ops
+  fi
 fi
 popd

 # CMake commands
 if [[ "$TARGET" == "et" ]]; then
-  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DET_USE_ADAPTIVE_THREADS=ON -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" -G Ninja
+  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DLINK_TORCHAO_OPS="${LINK_TORCHAO_OPS}" -DET_USE_ADAPTIVE_THREADS=ON -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=1" -G Ninja
 else
-  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -G Ninja
+  cmake -S . -B ./cmake-out -DCMAKE_PREFIX_PATH=`python3 -c 'import torch;print(torch.utils.cmake_prefix_path)'` -DLINK_TORCHAO_OPS="${LINK_TORCHAO_OPS}" -DCMAKE_CXX_FLAGS="-D_GLIBCXX_USE_CXX11_ABI=0" -G Ninja
 fi
 cmake --build ./cmake-out --target "${TARGET}"_run

torchchat/utils/scripts/build_torchao_ops.sh

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+#!/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+
+source "$(dirname "${BASH_SOURCE[0]}")/install_utils.sh"
+
+pushd ${TORCHCHAT_ROOT}
+find_cmake_prefix_path
+clone_torchao
+install_torchao_aten_ops
+popd

torchchat/utils/scripts/install_et.sh

Lines changed: 0 additions & 6 deletions
@@ -19,10 +19,4 @@ pushd ${TORCHCHAT_ROOT}
 find_cmake_prefix_path
 clone_executorch
 install_executorch_libs $ENABLE_ET_PYBIND
-install_executorch_python_libs $ENABLE_ET_PYBIND
-# TODO: figure out the root cause of 'AttributeError: module 'evaluate'
-# has no attribute 'utils'' error from evaluate CI jobs and remove
-# `import lm_eval` from torchchat.py since it requires a specific version
-# of numpy.
-pip install numpy=='1.26.4'
 popd
