Support addmm and split_copy in the Metal (AOTI) backend#19924
Support addmm and split_copy in the Metal (AOTI) backend#19924abdelaziz-mahdy wants to merge 2 commits into
Conversation
The experimental Metal backend could not lower common CNNs (MobileNet, YOLO) because two ops hit unsupported AOTI fallback kernels: - aten::split_copy.Tensor fell back to the proxy executor, which the AOTI runtime does not support. - aoti_torch_mps_addmm_out is emitted by inductor's mm+bias fusion (torch/_inductor/fx_passes/post_grad.py) during MPS codegen, but the libtorch-free runtime had no shim for it. Graph-level decomposition is insufficient because inductor re-fuses mm+bias back into addmm (and folds the size-1 unsqueeze that DecomposeLinearPass inserts for batch=1). Changes: - Map split_copy.Tensor -> split.Tensor in ReplaceViewCopyWithViewPass so inductor codegens it as views (like the existing slice_copy/select_copy). - Implement aoti_torch_mps_addmm_out (op_addmm.mm) via MPSGraph, mirroring op_mm.mm, and allow-list it in get_supported_fallback_kernels. - Add regression modules (addmm, split_cat, batch-1 linear) to test_modules. Verified end-to-end with executor_runner on macOS arm64: MobileNetV3-small, plus addmm / split_cat / linear(batch=1) numerics matching eager.
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19924
Note: Links to docs will display an error until the docs builds have been completed.
|
There was a problem hiding this comment.
Pull request overview
Note
Copilot was unable to run its full agentic suite in this review.
Adds Metal backend support for aten.addmm via a new c-shim that wraps an MPSGraph-cached addmm implementation, registers the kernel as a supported fallback, and extends the view-copy replacement pass to also rewrite split_copy (which has no c-shim) into the native split view op. Tests are added for both addmm and split+cat module patterns.
Changes:
- New
aoti_torch_mps_addmm_outMetal shim with graph caching, transposed-mat2 detection, and bias broadcasting. - Registration of the addmm shim in the backend's supported fallback kernel list and CMake build.
- Extended
replace_view_copy_with_viewto mapsplit_copy→split, plus newAddmm,LinearWithBias (batch=1), andSplitCattest modules.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
Show a summary per file
| File | Description |
|---|---|
| backends/apple/metal/runtime/ops/op_addmm.mm | New MPSGraph-backed addmm c-shim implementation. |
| backends/apple/metal/metal_backend.py | Registers aoti_torch_mps_addmm_out as a supported fallback kernel. |
| backends/apple/metal/CMakeLists.txt | Adds op_addmm.mm to the Metal AOTI sources. |
| backends/aoti/passes/replace_view_copy_with_view.py | Maps aten.split_copy.Tensor to aten.split.Tensor (both torch and edge variants). |
| backends/apple/metal/tests/test_modules.py | Adds Addmm, SplitCat, and linear_bias_batch1 test registry entries. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| // Bias may be 1-D [N] or 2-D [M, N]; feed its physical shape and rely on | ||
| // MPSGraph broadcasting in the addition. | ||
| NSMutableArray<NSNumber*>* biasShape = [NSMutableArray array]; | ||
| for (size_t i = 0; i < static_cast<size_t>(bias_tensor->dim()); ++i) { | ||
| [biasShape addObject:@(bias_tensor->sizes()[i])]; | ||
| } |
There was a problem hiding this comment.
Good catch — now keying on the full bias shape (rank + each dim) instead of just the rank. Fixed in 7dc3552.
| @try { | ||
| stream->executeMPSGraph(mpsGraph, feeds, results, SyncType::COMMIT); | ||
| } @catch (NSException* exception) { |
| throw std::runtime_error("MPSGraph execution failed with NSException"); | ||
| } | ||
|
|
||
| [mat1Data release]; | ||
| [mat2Data release]; | ||
| [biasData release]; | ||
| [outputData release]; |
| // Validate matmul operand dimensions. | ||
| if (mat1_tensor->dim() != 2 || mat2_tensor->dim() != 2) { |
There was a problem hiding this comment.
Left as-is for parity with op_mm (AOTInductor always allocates out at the right shape), but I added the dtype check across operands which covers the most likely mismatch here.
| int32_t dtype = static_cast<int32_t>(mat1_tensor->scalar_type()); | ||
| MPSDataType mps_dtype; | ||
| if (dtype == static_cast<int32_t>(SupportedDTypes::FLOAT32)) { | ||
| mps_dtype = MPSDataTypeFloat32; | ||
| } else if (dtype == static_cast<int32_t>(SupportedDTypes::BFLOAT16)) { | ||
| mps_dtype = MPSDataTypeBFloat16; | ||
| } else { | ||
| ET_LOG(Error, "aoti_torch_mps_addmm_out: unsupported dtype %d", dtype); | ||
| throw std::runtime_error("Unsupported data type for addmm"); | ||
| } |
There was a problem hiding this comment.
Added — verifies mat2/self/out all match mat1 dtype before building the graph. 7dc3552
| if (mat1_tensor->dim() != 2 || mat2_tensor->dim() != 2) { | ||
| std::string error_msg = | ||
| "aoti_torch_mps_addmm_out: mat1/mat2 must be 2-D, got " + | ||
| std::to_string(mat1_tensor->dim()) + " and " + | ||
| std::to_string(mat2_tensor->dim()); | ||
| ET_LOG(Error, "%s", error_msg.c_str()); | ||
| throw std::runtime_error(error_msg); | ||
| } |
There was a problem hiding this comment.
Kept consistent with the sibling shims (op_mm/op_bmm validate the same way) so addmm does not diverge.
This PR needs a
|
- Cache key: include the full bias shape (rank + each dim), not just the rank, so equal-rank but differently-shaped biases (e.g. [N] vs [1], or [M, N] vs [1, N]) don't collide and reuse a graph whose biasPlaceholder has the wrong shape. - Release the MPSGraphTensorData objects in an @finally so they aren't leaked when executeMPSGraph throws. - Validate that mat2/self/out share mat1's dtype before building the graph (return InvalidArgument on mismatch) to avoid silently reinterpreting buffers.
Summary
The experimental Metal (AOTI) backend can't lower common CNNs (MobileNetV3, YOLO) — export fails with "missing fallback kernels". Fixes #19907.
Two ops hit unsupported AOTI fallbacks:
aten::split_copy.Tensorfalls back to the proxy executor, which the AOTI runtime doesn't support.aoti_torch_mps_addmm_outis emitted by inductor'smm + bias → addmmfusion (torch/_inductor/fx_passes/post_grad.py) during MPS codegen, but the libtorch-free runtime has no shim for it. Graph-level decomposition alone is insufficient: inductor re-fusesmm + biasback intoaddmm, and forbatch=1it folds the size-1unsqueezethatDecomposeLinearPassinserts, so the fusion fires anyway.Changes
ReplaceViewCopyWithViewPass: mapsplit_copy.Tensor → split.Tensor(core + edge dialect), mirroring the existingslice_copy/select_copyhandling, so inductor codegens it as views instead of a proxy-executor fallback.runtime/ops/op_addmm.mm: implementsaoti_torch_mps_addmm_out(out = beta·self + alpha·(mat1 @ mat2)) viaMPSGraph, mirroringop_mm.mm(transposed-mat2handling + graph cache; cache key includesbeta/alpha). Registered inCMakeLists.txt.MetalBackend.get_supported_fallback_kernels: allow-listaoti_torch_mps_addmm_out.tests/test_modules.py: addaddmm,split_cat, andlinear_bias_batch1(batch=1 → the MobileNet-classifier case) regression modules.Test plan
Built
executor_runnerwith-DEXECUTORCH_BUILD_METAL=ONon macOS arm64 (Apple silicon) and ran the exported.ptes (input = ones); runtime outputs match eager:addmm(batch=1)linear_bias_batch1(all 101 values)split_cat[1, 1000]AOT export of MobileNetV3-small / MobileNetV2 / a YOLO-style head no longer raises "missing fallback kernels".