use python version agnostic binding for mxfp8 cuda kernels #3471

danielvegamyhre merged 3 commits into main.
Conversation
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3471

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure as of commit 39db553 with merge base 08e5e20. The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.
Spoke offline: we need to be using the torch library (TORCH_LIBRARY) API, not pybind.
Force-pushed from 4e59508 to 78a8b79.
```python
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
def test_cuda_mx_dim1_invalid_block_size():
```
Deleting this test, since a block_size of 32 is now hard-coded in the Python wrapper for the kernel (we always use that value for mxfp8).
Force-pushed from 78a8b79 to e576e41.
You should be able to use `nm -D` to inspect the symbols as well, and ensure there are none from Python.
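As a concrete sketch of that check (the demo library here is illustrative; the real target would be the built torchao extension `.so`), building a tiny shared library and running `nm -D` on it shows how to spot CPython symbols (names containing `Py`) that a non-ABI-stable pybind build would leave behind:

```shell
# Build a minimal C shared library to demonstrate the symbol check.
cat > demo.c <<'EOF'
int answer(void) { return 42; }
EOF
cc -shared -fPIC demo.c -o libdemo.so

# List dynamic symbols; our exported function is visible.
nm -D libdemo.so | grep answer

# Count CPython symbols; a python-version-agnostic .so should report 0.
nm -D libdemo.so | grep -c "Py" || true
```

Running the same `grep -c "Py"` against the extension's `.so` is a quick way to confirm no CPython ABI dependency remains.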
Current CI failures will be resolved once this rollback in upstream PyTorch is included in the next torch nightly: pytorch/pytorch#169985
The other upstream PyTorch issue causing CI failures has now been resolved: pytorch/pytorch#170184
The macOS test failure is unrelated.
* use py agnostic c++ extension for mxfp8_cuda
* refactor mxfp8 cuda from pybind to torch_library api
* put schema def inside guard
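The pybind-to-TORCH_LIBRARY refactor in the commits above can be sketched roughly as follows. This is illustrative only: the op name, signature, and wrapper are assumptions, not the PR's actual code. The key point is that TORCH_LIBRARY registration goes through libtorch's C++ ABI, so the resulting `.so` does not link against a specific CPython version:

```cpp
// Sketch: registering a CUDA op via the TORCH_LIBRARY API instead of pybind11.
#include <torch/library.h>
#include <ATen/core/Tensor.h>

namespace torchao {

// Hypothetical wrapper around the MXFP8 CUDA kernel launch; block_size of 32
// is hard-coded on the Python side, so it is not a parameter here.
at::Tensor mxfp8_quantize_cuda(const at::Tensor& input) {
  // ... launch the MXFP8 CUDA kernel here ...
  return input;
}

}  // namespace torchao

// Schema definition (the PR puts this inside a build guard so it is only
// registered when the extension is actually compiled for CUDA 12.8+).
TORCH_LIBRARY_FRAGMENT(torchao, m) {
  m.def("mxfp8_quantize_cuda(Tensor input) -> Tensor");
}

// Bind the CUDA implementation to the schema.
TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("mxfp8_quantize_cuda", &torchao::mxfp8_quantize_cuda);
}
```

From Python the op is then reached as `torch.ops.torchao.<op_name>` rather than by importing a pybind extension module.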
Summary
- Uses the same python-version-agnostic binding approach as `torchao/ops.py` does for the other CUDA C++ extensions
- Kernels now land in the `torchao._C_mxfp8` .so file (lands in build/ dir) instead of a separate `torchao.prototype.mxfp8_cuda` extension (which landed in torchao/prototype)

Context
While doing the 0.15.0 torchao release and testing the test build for CUDA 12.8, I found that the `torchao.prototype.mxfp8_cuda` C++ extension could not be found (import error, module not found). We only build the extension for CUDA 12.8+, so I checked the logs, and I see logs indicating it was built: https://github.com/pytorch/ao/actions/runs/20046209265/job/57498462190
I then checked the local installation itself, and I do see a .so file for the extension in the torchao/prototype dir, so it is definitely being built.
I tried asking Claude about this, and it said the Python version used for the build (3.10) must match the Python version in the conda env due to ABI incompatibility (I'm using Python 3.12). As a test, I tried a fresh conda env with Python 3.10, and instead of "module not found" I got an undefined symbol error, which does seem to indicate a Python ABI issue.
Asking @drisspg, he said we should be building with a py-agnostic flag. I looked into this, and we do this for the other C++ extensions but not for mxfp8_cuda, so I am fairly certain this is the root cause and this PR will fix the issue.
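For reference, a python-version-agnostic build of a torch C++ extension looks roughly like this in a setup.py (a sketch under assumptions: the source path and the limited-API version hex are illustrative, and torchao's actual setup.py is the source of truth):

```python
# Sketch of building an extension against CPython's stable limited API,
# so one .so works across Python minor versions. Paths/names are assumed.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="torchao",
    ext_modules=[
        CUDAExtension(
            name="torchao._C_mxfp8",
            sources=["torchao/csrc/cuda/mxfp8_extension.cpp"],  # illustrative path
            # Restrict compilation to the limited API (3.9+ in this sketch).
            extra_compile_args={"cxx": ["-DPy_LIMITED_API=0x03090000"]},
            py_limited_api=True,  # emit a CPython-version-agnostic extension
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)
```

With a TORCH_LIBRARY-based binding the extension ideally uses no CPython API at all, in which case the limited-API flag is just a guard that the build stays that way.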