11 changes: 10 additions & 1 deletion .github/workflows/mlx.yml
@@ -9,6 +9,7 @@ on:
paths:
- .github/workflows/mlx.yml
- backends/mlx/**
- extension/llm/export/**
workflow_dispatch:

permissions: {}
@@ -36,7 +37,7 @@ jobs:
${CONDA_RUN} pip list

echo "::group::Build test runners"
${CONDA_RUN} cmake --build cmake-out --target op_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
${CONDA_RUN} cmake --build cmake-out --target op_test_runner multi_thread_test_runner -j$(( $(sysctl -n hw.ncpu) - 1 ))
echo "::endgroup::"

echo "::group::Run op unit tests"
@@ -51,6 +52,14 @@ jobs:
-v
echo "::endgroup::"

echo "::group::Run multi-thread stress test"
${CONDA_RUN} python backends/mlx/test/export_multi_thread_test_model.py /tmp/multi_thread_test_model.pte
ET_TESTING_MODEL_PATH=/tmp/multi_thread_test_model.pte \
ET_TESTING_NUM_THREADS=50 \
ET_PREDICTIONS_PER_THREAD=100 \
./cmake-out/backends/mlx/test/multi_thread_test_runner
echo "::endgroup::"

backend-tester:
strategy:
fail-fast: false
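The multi-thread stress step added to the workflow above drives a C++ runner whose source is not part of this diff. As a rough sketch only of the pattern such a runner presumably implements — many threads share one loaded model and each issues repeated predictions, so races in shared state surface as wrong outputs or crashes. `predict` and `expected` are hypothetical stand-ins, not the runner's actual API:

import threading

NUM_THREADS = 50              # mirrors ET_TESTING_NUM_THREADS
PREDICTIONS_PER_THREAD = 100  # mirrors ET_PREDICTIONS_PER_THREAD

def run_stress(predict, expected) -> bool:
    failures = []

    def worker() -> None:
        for _ in range(PREDICTIONS_PER_THREAD):
            if predict() != expected:  # shared-state races show up here
                failures.append(threading.get_ident())

    threads = [threading.Thread(target=worker) for _ in range(NUM_THREADS)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return not failures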
122 changes: 107 additions & 15 deletions backends/mlx/builder/op_helpers.py
@@ -18,6 +18,11 @@
if TYPE_CHECKING:
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder

# When True, always serialize the biases tensor for quantized ops.
# When False, use init-time computation when zero_point is all zeros,
# computing biases = -scales * 2^(bits-1) during the init chain.
QUANTIZED_SERIALIZE_BIASES = False
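A quick numeric check of the identity this flag relies on (a sketch for illustration, not part of the PR): per the math in to_mlx_qparams below, the packed representation satisfies w = scale * Q + biases, so with an all-zero zero_point, computing biases = -scale * 2**(bits-1) at init time reproduces the serialized result exactly:

# Sketch (not part of the PR): biases = -scale * 2**(bits-1) makes
# scale * Q + biases match affine dequant when zero_point == 0.
import torch

bits = 4
offset = 2 ** (bits - 1)
scale = torch.tensor([0.05, 0.125])
qdata = torch.tensor([[-8, 7], [3, -1]], dtype=torch.int8)  # signed values
Q = qdata.to(torch.float32) + offset                        # packed, in [0, 15]

dequant = scale * qdata.to(torch.float32)                   # zero_point == 0
biases = -scale * offset                                    # the init-time multiply
assert torch.allclose(dequant, scale * Q + biases)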


def get_aten_target(target):
"""
@@ -168,6 +173,50 @@ def emit_lifted_constant(P: "MLXProgramBuilder", value, dtype: torch.dtype) -> S
return slot


def emit_quantized_biases(
P: "MLXProgramBuilder",
zero_point_key: str,
scale: torch.Tensor,
zero_point: torch.Tensor,
bits: int,
B: torch.Tensor,
scale_slot: "Slot",
) -> "Slot":
"""Emit biases for quantized ops, computing at init time when possible.

When zero_point is all zeros and QUANTIZED_SERIALIZE_BIASES is False,
avoids serializing the biases tensor by computing biases = scales * -offset
during the init chain instead.

Returns the biases Slot.
"""
from executorch.backends.mlx.serialization.mlx_graph_schema import MultiplyNode
from torch._subclasses.fake_tensor import FakeTensor

is_scale_only = False
if not isinstance(zero_point, FakeTensor):
if torch.sum(torch.abs(zero_point)).item() == 0:
is_scale_only = True

if QUANTIZED_SERIALIZE_BIASES or not is_scale_only:
return P.make_or_get_constant(f"{zero_point_key}_to_biases", B)

scale_dtype = scale.dtype
offset = 1 << (bits - 1)
neg_offset = emit_lifted_constant(P, -offset, scale_dtype)
biases = P.make_or_get_constant(
f"{zero_point_key}_to_biases_dummy", torch.tensor(0.0, dtype=B.dtype)
)
P.emit_init(
MultiplyNode(
a=P.slot_to_tid(scale_slot),
b=P.slot_to_tid(neg_offset),
out=P.slot_to_tid(biases),
)
)
return biases


def to_mlx_qparams(
qdata: torch.Tensor,
scale: torch.Tensor,
@@ -194,21 +243,36 @@ def parse_dequant_node(
"""
assert qdata.dtype == torch.int8
offset = 2 ** (bits - 1)
Q = qdata.to(torch.int32) + offset

# Pack data tightly into uint32
assert 32 % bits == 0
vals_per_uint32 = 32 // bits
assert qdata.shape[1] % vals_per_uint32 == 0

Q = Q.reshape(-1, vals_per_uint32)
shifts = torch.arange(0, 32, bits, dtype=torch.int64)

# Convert to int64 for shift/packing
Q = Q.to(torch.int64)
Q = (Q << shifts).sum(dim=-1)
Q = Q.to(torch.uint32)
Q = Q.reshape(qdata.shape[0], -1)
rows, cols = qdata.shape

if bits == 4:
# 4-bit: view(uint8) + wrapping add + pack 2 nibbles per byte → view as uint32
q = qdata.view(torch.uint8) + offset
q3 = q.reshape(rows, cols // 2, 2)
Q = (q3[:, :, 0] | (q3[:, :, 1] << 4)).view(torch.uint32)
elif bits == 2:
# 2-bit: pack 4×2-bit values per byte in uint8, then view as uint32
Q = ((qdata.view(torch.uint8) + offset) & 0x3).reshape(rows, cols // 4, 4)
packed = Q[:, :, 0] | (Q[:, :, 1] << 2) | (Q[:, :, 2] << 4) | (Q[:, :, 3] << 6)
Q = packed.contiguous().view(torch.uint32)
elif bits == 8:
# 8-bit: each byte maps 1:1 to a uint32 slot — no shifting needed
q = qdata.view(torch.uint8) + offset
Q = q.contiguous().view(torch.uint32).reshape(rows, -1)
else:
# General fallback for other bit widths
Q = (qdata.to(torch.int32) + offset).reshape(-1, vals_per_uint32)
shifts = torch.arange(0, 32, bits, dtype=torch.int32)
shifted = Q << shifts
packed = shifted[:, 0]
for i in range(1, vals_per_uint32):
packed = packed | shifted[:, i]
Q = packed.view(torch.uint32).reshape(rows, -1)

if compute_biases:
B = -scale * (zero_point.to(scale.dtype) + offset)
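A standalone round-trip check of the 4-bit fast path (a sketch, not part of the PR; assumes little-endian byte order, which is what Tensor.view(torch.uint32) yields on typical hosts):

import torch

# Pack: bias int8 values in [-8, 7] into [0, 15], two nibbles per byte,
# four bytes per uint32 — mirroring the bits == 4 branch above.
qdata = torch.randint(-8, 8, (2, 16), dtype=torch.int8)
offset = 8  # 2 ** (bits - 1) for bits == 4
q = qdata.view(torch.uint8) + offset          # wrapping add: -8..7 -> 0..15
pairs = q.reshape(2, 8, 2)
packed = (pairs[:, :, 0] | (pairs[:, :, 1] << 4)).view(torch.uint32)

# Unpack: reverse the byte view, split each byte back into two nibbles.
bytes_ = packed.view(torch.uint8).reshape(2, 8)
lo, hi = bytes_ & 0xF, bytes_ >> 4
restored = torch.stack([lo, hi], dim=-1).reshape(2, 16).to(torch.int16) - offset
assert torch.equal(restored, qdata.to(torch.int16))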
@@ -217,6 +281,34 @@
return Q, None


def parse_dequant_nvfp4_node(
node: Node,
) -> Optional[Tuple[Node, Node, Node, torch.dtype]]:
"""Parse a torchao.dequantize_nvfp4 node.

Returns (qdata, scale, per_tensor_scale, output_dtype) or None if not a
dequantize_nvfp4 node or the custom op is not registered.
"""
target = get_aten_target(node.target)
try:
import executorch.extension.llm.export.nvfp4 # noqa: F401
except ImportError:
return None

if target is not torch.ops.torchao.dequantize_nvfp4.default:
return None

qdata, scale, per_tensor_scale = node.args[0:3]

output_dtype = torch.float32
if len(node.args) > 4:
output_dtype = node.args[4]
elif "output_dtype" in node.kwargs:
output_dtype = node.kwargs["output_dtype"]

return qdata, scale, per_tensor_scale, output_dtype


def parse_dequant_node(
node: Node,
) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]:
@@ -244,11 +336,11 @@ def parse_dequant_node(
quantized_dim, group_size = non_one[0]
if group_size not in [32, 64, 128]:
return None
if qmin == -8 and qmax == 7:
bits = 4
elif qmin == -128 and qmax == 127:
bits = 8
else:

# TODO: MLX supports 3, 5, and 7, but we need to figure out the
# packing story in to_mlx_qparams to use them
bits = (qmax - qmin + 1).bit_length() - 1
if bits not in [2, 4, 8]:
return None
return qdata, scale, zero_point, group_size, bits, out_dtype, quantized_dim
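A quick check of the generalized bit-width derivation (sketch, not part of the PR): for a full signed b-bit range, qmax - qmin + 1 == 2**b, so bit_length() - 1 recovers b and reproduces the old qmin/qmax special cases:

for qmin, qmax, expected_bits in [(-2, 1, 2), (-8, 7, 4), (-128, 127, 8)]:
    assert (qmax - qmin + 1).bit_length() - 1 == expected_bits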
