
[Metal] FP8 storage-only emulation (uchar storage + LUT decode helpers)#38

Open
apstenku123 wants to merge 1 commit into tile-ai:tilelang_main from apstenku123:cppmega/metal-fp8-storage-only

Conversation


@apstenku123 apstenku123 commented May 4, 2026

Summary

This is the TVM-mirror half of a 2-PR pair adding storage-only FP8 emulation to the Metal codegen. The companion TileLang PR lives in tile-ai/tilelang (link below) and mirrors the same change to TileLang's CodeGenTileLangMetal specialisation at src/target/codegen_metal.{cc,h}.

Apple Silicon (M1-M4 / M5 NAX) has no native FP8 ALU support, MSL has no float8 scalar type, and simdgroup_matrix<uchar, ...> fails the Metal stdlib element-type assertion. The only viable representation is "storage-only FP8": pack 8-bit values in uchar / ucharN buffers, dequantize on load to half, do the math in half (or float accumulator), and quantize back on store. This mirrors how mxfp8/nvfp4 are realised in MLX core, and how stock TVM's CUDA codegen handles pre-sm89 hardware (src/target/source/codegen_cuda.cc:410-417).
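To make the pattern concrete, here is a minimal MSL sketch of what a lowered kernel looks like under this scheme (the kernel and buffer names are hypothetical; the __tvm_* helpers are the ones this PR emits):

```cpp
// Storage-only FP8 in MSL: uchar in memory, half for the arithmetic.
kernel void fp8_scale(device const uchar *x [[buffer(0)]],
                      device uchar       *y [[buffer(1)]],
                      constant half      &s [[buffer(2)]],
                      uint i [[thread_position_in_grid]]) {
  half v = __tvm_fp8_e4m3_to_half(x[i]);  // dequantize on load
  v = v * s;                              // math in half
  y[i] = __tvm_half_to_fp8_e4m3(v);       // quantize back on store
}
```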

Without this patch, any float8_e4m3 / float8_e5m2 / float8_e8m0fnu dtype reaching CodeGenMetal::PrintType triggers LOG(FATAL) << "Cannot convert type " << t << " to Metal type", which means sparse-MLA FP8, blockscaled, and mxfp8 lowering on the Metal target are unreachable. CUDA codegen has had this path for years; Metal had no equivalent.

What this PR adds

In src/target/source/codegen_metal.{cc,h}:

  1. PrintType FP8 case — when t.is_float8(), emit uchar for lanes==1, ucharN for lanes∈[2,4], uint2 for lanes==8, and uint4 for lanes==16, and set enable_fp8_=true so the prelude is emitted (see the sketch after this list). This mirrors the CUDA codegen's behaviour, where FP8 vectors wider than 4 lanes are packed into wider integer storage.
  2. PrintFP8Prelude — inline MSL helpers __tvm_fp8_e4m3_to_half, __tvm_fp8_e5m2_to_half, __tvm_half_to_fp8_e4m3, __tvm_half_to_fp8_e5m2. Encodings follow the OCP "OFP8 Formats for Deep Learning" v1.0 spec: E4M3 uses the finite-only encoding (S.1111.111 is NaN, no Inf), while E5M2 uses IEEE-style NaN/Inf. Both half-to-FP8 directions implement round-to-nearest-even on the discarded mantissa bits (an illustrative decode sketch also follows this list).
  3. VisitExpr_(CastNode) override — when either side is FP8, scalar casts route through the helpers. Vector casts (lanes>1) raise a clear LOG(FATAL) directing the caller to scalarise — the TVM tir.transform.legalize_fp8 pass already scalarises most user FP8 casts, so this branch is rarely hit.
  4. Finish() override — if enable_fp8_ was set, splice the prelude right after using namespace metal; so the helpers see the MSL namespace.
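For item 1, a simplified sketch of the PrintType branch (member and variable names follow the PR description; the actual patch is larger):

```cpp
// Sketch: FP8 case inside CodeGenMetal::PrintType (simplified).
if (t.is_float8()) {
  enable_fp8_ = true;                             // prelude gets spliced in Finish()
  int lanes = t.lanes();
  if (lanes == 1)        os << "uchar";
  else if (lanes <= 4)   os << "uchar" << lanes;  // uchar2 / uchar3 / uchar4
  else if (lanes == 8)   os << "uint2";           // 8 bytes packed into 2x uint32
  else if (lanes == 16)  os << "uint4";           // 16 bytes packed into 4x uint32
  else LOG(FATAL) << "Unsupported FP8 vector width " << lanes;
  return;
}
```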
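For item 2, an illustrative MSL decode showing the e4m3 rules the helpers implement (rebias 7 to 15, finite-only NaN); this sketches the OFP8 encoding, not the literal emitted helper:

```cpp
// e4m3 -> half: rebias the exponent, widen the mantissa, map S.1111.111 to NaN.
inline half fp8_e4m3_to_half_sketch(uchar v) {
  ushort s = (v >> 7) & 1, e = (v >> 3) & 0xF, m = v & 0x7;
  if (e == 0xF && m == 0x7)                 // the only NaN encoding; e4m3 has no Inf
    return as_type<half>(ushort((s << 15) | 0x7E00));
  if (e == 0)                               // subnormal: (m/8) * 2^-6 = m * 2^-9
    return (s ? -1.0h : 1.0h) * half(m) * half(1.0h / 512.0h);
  // normal: half exponent = e - 7 + 15, mantissa widened from 3 to 10 bits
  return as_type<half>(ushort((s << 15) | ((e + 8) << 10) | (m << 7)));
}
```

e5m2 -> half needs no arithmetic at all: both formats use a 5-bit exponent with bias 15, so the decode is just the byte shifted into the top bits of a ushort; only the encode direction has to round.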

The stock TVM Metal codegen path (target.build.metal) goes through legalize_fp8 first, which expands FP8 ops into bit-shuffle code inline. With this patch its PrintType no longer faults, so the legalised output also compiles with xcrun metal -c.

Motivation: cppmega.mlx workaround

In the cppmega.mlx port (cppmega_mlx/nn/_tilelang/fp8_msl_kernels.py) we currently ship audiohacking-style FP8 MSL kernels as raw mx.fast.metal_kernel(source=...) strings to bypass this codegen FATAL. With this PR landed in the vendored TVM mirror and bumped into TileLang, those workaround kernels can be replaced by ordinary T.cast(half, fp8_load) chains lowered through target.build.metal.

Companion PR

tile-ai/tilelang#2144, the supermodule half (per the commit reference later in this thread).

Stack

This PR targets tile-ai/tvm:tilelang_main and is rebased on HEAD 0e15b274 (the SHA TileLang's 3rdparty/tvm submodule pins).

Diff stat

 src/target/source/codegen_metal.cc | 203 +++++++++++++++++++++++++++++
 src/target/source/codegen_metal.h  |  10 +
 2 files changed, 213 insertions(+)

Test plan

  • git apply --check passes cleanly against TileLang/tvm@0e15b274
  • All four scalar cast directions lower cleanly through lower(prim_func, target=tvm.target.Target("metal")):
    • fp8_e4m3 -> half
    • half -> fp8_e4m3
    • fp8_e5m2 -> half
    • half -> fp8_e5m2
  • Generated MSL compiles via xcrun --sdk macosx metal -c for any prim_func with an FP8 dtype lowered to MSL with the inline helpers
  • E4M3 subnormal correctness verified byte-by-byte against PyTorch torch.float8_e4m3fn and an mlx.from_fp8 reference, over the full 256-byte e4m3 range (a host-side decode sketch follows this list)
  • mxfp8 (float8_e8m0fnu scale storage) lowers correctly: it's a device uchar* buffer with no helper calls (just pass-through)
  • Reviewer build: TVM cmake -DUSE_METAL=ON && ninja against the patched src/target/source/codegen_metal.cc
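For reviewers who want to reproduce the byte-by-byte check without a GPU, a self-contained host-side sketch of the reference e4m3 decode (an illustration only, not part of the patch; the actual comparison against torch.float8_e4m3fn requires PyTorch):

```cpp
// Exhaustively decode all 256 e4m3 bytes and confirm the finite encodings
// map to distinct values, which is what makes a byte-exact comparison
// against torch.float8_e4m3fn meaningful.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <set>

static float decode_e4m3(uint8_t v) {        // OFP8 e4m3: bias 7, finite-only
  int s = (v >> 7) & 1, e = (v >> 3) & 0xF, m = v & 0x7;
  float sign = s ? -1.0f : 1.0f;
  if (e == 0xF && m == 0x7) return NAN;      // S.1111.111 is NaN, no Inf
  if (e == 0) return sign * m * 0x1p-9f;     // subnormal: (m/8) * 2^-6
  return sign * (1.0f + m / 8.0f) * std::ldexp(1.0f, e - 7);
}

int main() {
  std::set<float> distinct;
  int finite = 0;
  for (int b = 0; b < 256; ++b) {
    float f = decode_e4m3(static_cast<uint8_t>(b));
    if (std::isnan(f)) continue;
    ++finite;
    distinct.insert(f);                      // +0.0 and -0.0 collapse to one key
  }
  std::printf("finite: %d distinct: %zu\n", finite, distinct.size());
  // expect: finite = 254 (two NaN bytes), distinct = 253 (+/-0 collapse)
  return 0;
}
```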

Risk

  • Limited to the Metal codegen. CUDA, ROCm, OpenCL, Vulkan, WebGPU, and CPU codegen are unaffected.
  • Vector FP8 casts (lanes>1) raise a clear FATAL with guidance; callers route through the existing legalize_fp8 scalarise pass.
  • The helpers are emitted whenever any FP8 dtype is referenced, even if only one direction is used; Apple's Metal compiler dead-strips the unused ones (~80 lines of inlined IR per kernel, which is negligible).

Copilot AI review requested due to automatic review settings May 4, 2026 10:07

Copilot AI left a comment


Copilot wasn't able to review this pull request because it exceeds the maximum number of files (300). Try reducing the number of changed files and requesting a review from Copilot again.

@apstenku123 apstenku123 changed the base branch from main to tilelang_main May 4, 2026 10:07
apstenku123 added a commit to DatasunriseOU/cppmega_mlx that referenced this pull request May 4, 2026
…, #37/#38/#39)

Three parallel agents completed the supermodule/submodule split filing:

1. tilelang_metal_fp8 (storage-only FP8 emulation) split:
   - 0001-tilelang-metal-fp8-storage-only.patch — supermodule half (235 lines)
   - 0002-tvm-metal-fp8-storage-only.patch — TVM-mirror half (260 lines, prefix stripped)
   - PR tile-ai/tilelang#2144 (supermodule, stacks on PR #2130)
   - PR tile-ai/tvm#38 (TVM mirror, base tilelang_main @ 0e15b274)

2. tilelang_metal_fp8_vector (vector cast lanes 2/3/4) split:
   - 0001-tilelang-metal-fp8-vector-cast.patch — supermodule half (148 lines)
   - 0002-tvm-metal-fp8-vector-cast.patch — TVM-mirror half (151 lines)
   - PR tile-ai/tilelang#2145 (supermodule, depends on #2144)
   - PR tile-ai/tvm#39 (TVM mirror, depends on #38)

3. PR #2143 TVM-mirror companion:
   - PR tile-ai/tvm#37 — already filed, README updated to link both halves

Total filed today: 11 PRs across 3 repos
- 1 ml-explore/mlx (#3476)
- 1 apache/tvm (#19504)
- 6 tile-ai/tilelang (#2139, #2140, #2141, #2142, #2143 super, #2144 super, #2145 super)
- 3 tile-ai/tvm (#37, #38, #39 — TVM-mirror companions)

PR #2142 (T.fp8_scaled_matmul) has no TVM-mirror companion needed —
verified the patch only touches supermodule files.

All splits round-trip clean (apply forward + reverse) on their respective
bases. README files in each docs/upstream/<dir>/ updated with PR URLs and
dependency-chain diagrams.

Note: TileLang/tvm redirects to tile-ai/tvm server-side (canonical org
slug). All TVM-mirror PRs land at tile-ai/tvm/pull/N URLs.