diff --git a/torchao/utils.py b/torchao/utils.py index 02013c5197..e7ac4e5390 100644 --- a/torchao/utils.py +++ b/torchao/utils.py @@ -1153,7 +1153,10 @@ def is_package_at_least(package_name: str, min_version: str): def _is_fbgemm_gpu_genai_available(): # TODO: use is_package_at_least("fbgemm_gpu", "1.2.0") when # https://github.com/pytorch/FBGEMM/issues/4198 is fixed - if importlib.util.find_spec("fbgemm_gpu") is None: + if ( + importlib.util.find_spec("fbgemm_gpu") is None + or importlib.util.find_spec("fbgemm_gpu.experimental") is None + ): return False import fbgemm_gpu.experimental.gen_ai # noqa: F401