19 changes: 14 additions & 5 deletions .ci/scripts/export_model_artifact.sh
@@ -26,13 +26,15 @@ Arguments:
quant_name Quantization type (optional, default: non-quantized)
Options:
- non-quantized
- quantized-int4-tile-packed
- quantized-int4-weight-only
- quantized-int4-tile-packed (CUDA only)
- quantized-int4-weight-only (CUDA only)
- quantized-int4-metal (Metal only)

output_dir Output directory for artifacts (optional, default: current directory)

Examples:
export_model_artifact.sh metal "openai/whisper-small"
export_model_artifact.sh metal "nvidia/parakeet-tdt" "quantized-int4-metal"
export_model_artifact.sh cuda "mistralai/Voxtral-Mini-3B-2507" "quantized-int4-tile-packed"
export_model_artifact.sh cuda "google/gemma-3-4b-it" "non-quantized" "./output"
export_model_artifact.sh cuda "nvidia/parakeet-tdt" "non-quantized" "./output"
@@ -127,21 +129,28 @@ case "$QUANT_NAME" in
;;
quantized-int4-tile-packed)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear 4w --qlinear_encoder 4w --qlinear_packing_format tile_packed_to_4d --qlinear_encoder_packing_format tile_packed_to_4d"
;;
quantized-int4-weight-only)
if [ "$DEVICE" = "metal" ]; then
echo "Error: Metal backend does not yet support quantization '$QUANT_NAME'"
echo "Error: Metal backend does not support quantization '$QUANT_NAME'"
exit 1
fi
EXTRA_ARGS="--qlinear_encoder 4w"
;;
quantized-int4-metal)
if [ "$DEVICE" != "metal" ]; then
echo "Error: Quantization '$QUANT_NAME' only supported on Metal backend"
exit 1
fi
EXTRA_ARGS="--qlinear fpa4w --qlinear_encoder fpa4w"
;;
*)
echo "Error: Unsupported quantization '$QUANT_NAME'"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only"
echo "Supported quantizations: non-quantized, quantized-int4-tile-packed, quantized-int4-weight-only, quantized-int4-metal"
exit 1
;;
esac
12 changes: 12 additions & 0 deletions .github/workflows/metal.yml
@@ -73,6 +73,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
@@ -123,6 +129,12 @@ jobs:
name: "parakeet-tdt"
quant:
- "non-quantized"
# Only test int4 quantization with parakeet-tdt
include:
- model:
repo: "nvidia"
name: "parakeet-tdt"
quant: "quantized-int4-metal"
with:
runner: macos-m2-stable
python-version: '3.11'
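For readers less familiar with GitHub Actions matrices: the base jobs are the cross-product of the `model` and `quant` axes, and each `include` entry appends one extra combination instead of multiplying the whole matrix, which is how the workflow tests `quantized-int4-metal` against parakeet-tdt only. A rough Python sketch of that expansion (both hunks are truncated, so the model/quant lists below are illustrative, not the workflow's full matrix):

```python
from itertools import product

# Illustrative axes only -- the real workflow defines more models than this hunk shows.
models = [{"repo": "nvidia", "name": "parakeet-tdt"}]
quants = ["non-quantized"]

# Base matrix: every model paired with every quant value.
jobs = [{"model": m, "quant": q} for m, q in product(models, quants)]

# An `include` entry that does not match an existing combination is appended as a new job,
# so only parakeet-tdt picks up the quantized-int4-metal variant.
jobs.append({"model": {"repo": "nvidia", "name": "parakeet-tdt"}, "quant": "quantized-int4-metal"})

for job in jobs:
    print(f'{job["model"]["repo"]}/{job["model"]["name"]} -> {job["quant"]}')
```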
137 changes: 90 additions & 47 deletions backends/apple/metal/passes/decompose_linear_pass.py
@@ -6,7 +6,7 @@

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeLinearPass(ExportPass):
@@ -20,49 +20,92 @@ class DecomposeLinearPass(ExportPass):
then squeeze back to 2D.
"""

def call_operator(self, op, args, kwargs, meta):
# Only intercept linear operations
if op not in (exir_ops.edge.aten.linear.default, torch.ops.aten.linear.default):
return super().call_operator(op, args, kwargs, meta)

# Get input, weight, and bias arguments
input_arg = args[0]
weight_arg = args[1]
bias_arg = args[2] if len(args) > 2 else None

# Determine which ops to use based on the input operator
if op == exir_ops.edge.aten.linear.default:
t_op = exir_ops.edge.aten.t.default
matmul_op = exir_ops.edge.aten.matmul.default
add_op = exir_ops.edge.aten.add.Tensor
unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
squeeze_op = exir_ops.edge.aten.squeeze.dims
else:
t_op = torch.ops.aten.t.default
matmul_op = torch.ops.aten.matmul.default
add_op = torch.ops.aten.add.Tensor
unsqueeze_op = torch.ops.aten.unsqueeze.default
squeeze_op = torch.ops.aten.squeeze.dims

# Check if input is 2D from metadata
needs_unsqueeze = len(meta["val"].shape) == 2

# Unsqueeze 2D input to 3D: (M, K) -> (1, M, K)
if needs_unsqueeze:
input_arg = super().call_operator(unsqueeze_op, (input_arg, 0), {}, meta)

# Transpose weight
weight_t = super().call_operator(t_op, (weight_arg,), {}, meta)

# Matmul
result = super().call_operator(matmul_op, (input_arg, weight_t), {}, meta)

# Add bias if present
if bias_arg is not None:
result = super().call_operator(add_op, (result, bias_arg), {}, meta)

# Squeeze 3D output back to 2D: (1, M, N) -> (M, N)
if needs_unsqueeze:
result = super().call_operator(squeeze_op, (result, [0]), {}, meta)

return result
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
modified = False
graph = graph_module.graph

for node in graph.nodes:
# Check if this is a linear operation
is_linear = False

if node.op == "call_function":
# Match both edge dialect and core aten linear operators
if node.target == exir_ops.edge.aten.linear.default:
is_linear = True
elif node.target == torch.ops.aten.linear.default:
is_linear = True

if is_linear:
# Get input, weight, and bias arguments
input_node = node.args[0]
weight_node = node.args[1]
bias_node = node.args[2] if len(node.args) > 2 else None

with graph.inserting_before(node):
# Determine which ops to use based on the input operator
target_str = str(node.target)

if "executorch_exir_dialects_edge" in target_str:
# Use edge dialect operators
t_op = exir_ops.edge.aten.t.default
matmul_op = exir_ops.edge.aten.matmul.default
add_op = exir_ops.edge.aten.add.Tensor
unsqueeze_op = exir_ops.edge.aten.unsqueeze.default
squeeze_op = exir_ops.edge.aten.squeeze.dims
else:
# Use core aten operators
t_op = torch.ops.aten.t.default
matmul_op = torch.ops.aten.matmul.default
add_op = torch.ops.aten.add.Tensor
unsqueeze_op = torch.ops.aten.unsqueeze.default
squeeze_op = torch.ops.aten.squeeze.dims

# Check if input is 2D
needs_unsqueeze = False
if hasattr(input_node, "meta") and "val" in input_node.meta:
if len(input_node.meta["val"].shape) == 2:
needs_unsqueeze = True

# Unsqueeze 2D input to 3D: (M, K) -> (1, M, K)
current_input = input_node
if needs_unsqueeze:
current_input = graph.call_function(
unsqueeze_op,
args=(input_node, 0),
)

# Decompose linear: matmul(input, weight.T) + bias
weight_t = graph.call_function(
t_op,
args=(weight_node,),
)

matmul_result = graph.call_function(
matmul_op,
args=(current_input, weight_t),
)

if bias_node is not None:
result = graph.call_function(
add_op,
args=(matmul_result, bias_node),
)
else:
result = matmul_result

# Squeeze 3D output back to 2D: (1, M, N) -> (M, N)
if needs_unsqueeze:
result = graph.call_function(
squeeze_op,
args=(result, [0]),
)

# Replace all uses of the linear node with the decomposed result
node.replace_all_uses_with(result)
graph.erase_node(node)
modified = True

if modified:
graph_module.recompile()

return PassResult(graph_module, modified)
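The rewritten pass now mutates the FX graph directly and reports whether anything changed via `PassResult`. A minimal sketch of how the decomposition could be sanity-checked on a toy module (the import path is inferred from the file location in this diff, and the module and shapes are made up for illustration, not taken from the repository's tests):

```python
import torch
from torch.export import export

# Import path inferred from this file's location in the tree; adjust if it differs.
from executorch.backends.apple.metal.passes.decompose_linear_pass import DecomposeLinearPass


class TinyLinear(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.fc = torch.nn.Linear(8, 4)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


# Export with a 2D input so the unsqueeze -> matmul -> squeeze path is exercised.
ep = export(TinyLinear(), (torch.randn(3, 8),))
gm = ep.graph_module

result = DecomposeLinearPass()(gm)
targets = {str(n.target) for n in result.graph_module.graph.nodes if n.op == "call_function"}

assert not any("linear" in t for t in targets), targets  # linear should be decomposed away
assert any("matmul" in t for t in targets), targets      # replaced by matmul (+ optional add)
```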
23 changes: 19 additions & 4 deletions backends/apple/metal/tests/test_modules.py
@@ -31,10 +31,15 @@
from torch.export import export
from torch.nn.attention import SDPBackend

# Need to import to load the ops
import torchao.experimental.ops.mps # noqa: F401
from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.quant_api import quantize_
try:
# Need to import to load the ops
import torchao.experimental.ops.mps # noqa: F401
from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.quant_api import quantize_

TORCHAO_AVAILABLE = True
except ImportError:
TORCHAO_AVAILABLE = False


# Check if MPS is available for export tests
@@ -241,6 +246,7 @@ def forward(self, x: torch.Tensor):
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


@@ -265,6 +271,7 @@ def forward(self, x: torch.Tensor):
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


@@ -289,6 +296,7 @@ def forward(self, x: torch.Tensor):
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


@@ -313,6 +321,7 @@ def forward(self, x: torch.Tensor):
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


@@ -337,6 +346,7 @@ def forward(self, x: torch.Tensor):
"rtol_float32": 5e-2,
"atol_bfloat16": 1e-1,
"rtol_bfloat16": 1e-1,
"skip": not TORCHAO_AVAILABLE,
}


@@ -688,6 +698,11 @@ def quantize_model(model: nn.Module, qlinear: str, qlinear_group_size: int = 32)
- "fpa4w": Floating point activation, 4-bit weight (Metal backend)
qlinear_group_size: Group size for quantization (default: 32).
"""
if not TORCHAO_AVAILABLE:
raise RuntimeError(
"torchao is not available. Install torchao to use quantization."
)

if qlinear == "fpa4w":
linear_config = UIntxWeightOnlyConfig(
group_size=qlinear_group_size,
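The hunks above gate the quantized test configs on a `skip` flag backed by a `TORCHAO_AVAILABLE` probe; the harness that consumes that flag sits outside this diff. A sketch of the general pattern, under the assumption that a unittest-style runner honors the probe (the test class and method here are hypothetical, not part of test_modules.py):

```python
import unittest

try:
    # Same probe as in the diff: importing the MPS ops doubles as the availability check.
    import torchao.experimental.ops.mps  # noqa: F401

    TORCHAO_AVAILABLE = True
except ImportError:
    TORCHAO_AVAILABLE = False


class QuantizedLinearSmokeTest(unittest.TestCase):
    @unittest.skipUnless(TORCHAO_AVAILABLE, "torchao with MPS ops is not installed")
    def test_fpa4w_config_runs(self) -> None:
        # Hypothetical body; the real tests are driven by config dicts that carry a
        # "skip" entry instead of a decorator.
        self.assertTrue(TORCHAO_AVAILABLE)


if __name__ == "__main__":
    unittest.main()
```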
31 changes: 22 additions & 9 deletions examples/models/parakeet/README.md
@@ -39,25 +39,26 @@ The export script supports quantizing encoder and decoder linear layers using [t

| Argument | Description |
|----------|-------------|
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear_encoder` | Quantization config for encoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_encoder_group_size` | Group size for encoder linear quantization (default: 32) |
| `--qlinear_encoder_packing_format` | Packing format for encoder: `tile_packed_to_4d` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w` |
| `--qlinear` | Quantization config for decoder linear layers: `4w`, `8w`, `8da4w`, `8da8w`, `fpa4w` |
| `--qlinear_group_size` | Group size for decoder linear quantization (default: 32) |
| `--qlinear_packing_format` | Packing format for decoder: `tile_packed_to_4d` |
| `--qembedding` | Quantization config for decoder embedding layer: `4w`, `8w` |
| `--qembedding_group_size` | Group size for embedding quantization (default: 0 = per-axis) |

#### Quantization Configs

| Config | Description |
|--------|-------------|
| `4w` | 4-bit weight only quantization |
| `8w` | 8-bit weight only quantization |
| `8da4w` | 8-bit dynamic activation, 4-bit weight |
| `8da8w` | 8-bit dynamic activation, 8-bit weight |
| Config | Description | Backends |
|--------|-------------|----------|
| `4w` | 4-bit weight only quantization | CUDA |
| `8w` | 8-bit weight only quantization | CUDA |
| `8da4w` | 8-bit dynamic activation, 4-bit weight | CUDA |
| `8da8w` | 8-bit dynamic activation, 8-bit weight | CUDA |
| `fpa4w` | Floating point activation, 4-bit weight | Metal |

#### Example: 4-bit Weight Quantization with Tile Packing
#### Example: 4-bit Weight Quantization with Tile Packing (CUDA)

```bash
python export_parakeet_tdt.py \
@@ -74,6 +75,18 @@ python export_parakeet_tdt.py \

**Note:** The `tile_packed_to_4d` packing format is optimized for CUDA.

#### Example: Metal 4-bit Quantization

```bash
python export_parakeet_tdt.py \
--backend metal \
--qlinear_encoder fpa4w \
--qlinear_encoder_group_size 32 \
--qlinear fpa4w \
--qlinear_group_size 32 \
--output-dir ./parakeet_metal_quantized
```

### Metal Export (macOS)

```bash
10 changes: 8 additions & 2 deletions examples/models/parakeet/export_parakeet_tdt.py
@@ -583,7 +583,7 @@ def main():
parser.add_argument(
"--qlinear",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for decoder linear layers",
)
parser.add_argument(
@@ -603,7 +603,7 @@
parser.add_argument(
"--qlinear_encoder",
type=str,
choices=["4w", "8w", "8da4w", "8da8w"],
choices=["4w", "8w", "8da4w", "8da8w", "fpa4w"],
help="Quantization config for encoder linear layers",
)
parser.add_argument(
@@ -639,6 +639,12 @@ def main():
if args.dtype == "fp16":
parser.error("fp16 is not yet supported")

# Validate fpa4w quantization requires Metal backend
if args.qlinear == "fpa4w" and args.backend != "metal":
parser.error("--qlinear=fpa4w can only be used with --backend=metal")
if args.qlinear_encoder == "fpa4w" and args.backend != "metal":
parser.error("--qlinear_encoder=fpa4w can only be used with --backend=metal")

os.makedirs(args.output_dir, exist_ok=True)

print("Extracting tokenizer...")
Expand Down