Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 37 additions & 17 deletions torchao/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,43 @@

logger = logging.getLogger(__name__)

try:
from pathlib import Path

so_files = list(Path(__file__).parent.glob("_C*.so"))
if len(so_files) > 0:
for file in so_files:
torch.ops.load_library(str(file))
from . import ops

# The following library contains CPU kernels from torchao/experimental
# They are built automatically by ao/setup.py if on an ARM machine.
# They can also be built outside of the torchao install process by
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
from torchao.experimental.op_lib import * # noqa: F403
except Exception as e:
logger.debug(f"Skipping import of cpp extensions: {e}")
skip_loading_so_files = False
# if torchao version has "+git", assume it's locally built and we don't know
# anything about the PyTorch version used to build it
# otherwise, assume it's prebuilt by torchao's build scripts and we can make
# assumptions about the PyTorch version used to build it.
if (not "+git" in __version__) and not ("unknown" in __version__):
# torchao v0.13.0 is built with PyTorch 2.8.0. We know that torchao .so
# files built using PyTorch 2.8.0 are not ABI compatible with PyTorch 2.9+.
# The following code skips importing the .so files if PyTorch 2.9+ is
# detected, to avoid crashing the Python process with "Aborted (core
# dumped)".
# TODO(#2901, and before next torchao release): make this generic for
# future torchao and torch versions
if __version__.startswith("0.13.0") and torch.__version__ > "2.8":
logger.warning(
f"Skipping import of cpp extensions due to incompatible torch version {torch.__version__} for torchao version {__version__}"
)
skip_loading_so_files = True

if not skip_loading_so_files:
try:
from pathlib import Path

so_files = list(Path(__file__).parent.glob("_C*.so"))
if len(so_files) > 0:
for file in so_files:
torch.ops.load_library(str(file))
from . import ops

# The following library contains CPU kernels from torchao/experimental
# They are built automatically by ao/setup.py if on an ARM machine.
# They can also be built outside of the torchao install process by
# running the script `torchao/experimental/build_torchao_ops.sh <aten|executorch>`
# For more information, see https://github.com/pytorch/ao/blob/main/torchao/experimental/docs/readme.md
from torchao.experimental.op_lib import * # noqa: F403
except Exception as e:
logger.warning(f"Skipping import of cpp extensions: {e}")

from torchao.quantization import (
autoquant,
Expand Down
Loading