diff --git a/torch/onnx/_internal/fx/patcher.py b/torch/onnx/_internal/fx/patcher.py index e2262bba4f7a5..ee919eae00d1a 100644 --- a/torch/onnx/_internal/fx/patcher.py +++ b/torch/onnx/_internal/fx/patcher.py @@ -1,19 +1,24 @@ import copy +import functools import io from typing import List, Union import torch + # TODO: Remove after https://github.com/huggingface/safetensors/pull/318 -try: - # safetensors is not an exporter requirement, but needed for some huggingface models - import safetensors # type: ignore[import] # noqa: F401 - import transformers # type: ignore[import] - from safetensors import torch as safetensors_torch # noqa: F401 +@functools.lru_cache(None) +def has_safetensors_and_transformers(): + try: + # safetensors is not an exporter requirement, but needed for some huggingface models + import safetensors # type: ignore[import] # noqa: F401 + import transformers # type: ignore[import] # noqa: F401 + + from safetensors import torch as safetensors_torch # noqa: F401 - has_safetensors_and_transformers = True -except ImportError: - has_safetensors_and_transformers = False + return True + except ImportError: + return False class ONNXTorchPatcher: @@ -61,7 +66,9 @@ def torch_load_wrapper(f, *args, **kwargs): # Wrapper or modified version of torch functions. self.torch_load_wrapper = torch_load_wrapper - if has_safetensors_and_transformers: + if has_safetensors_and_transformers(): + import safetensors + import transformers def safetensors_load_file_wrapper(filename, device="cpu"): # Record path for later serialization into ONNX proto @@ -109,7 +116,10 @@ def __enter__(self): desired_wrapped_methods.append((torch.Tensor, "__getitem__")) torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods - if has_safetensors_and_transformers: + if has_safetensors_and_transformers(): + import safetensors + import transformers + safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper transformers.modeling_utils.safe_load_file = ( self.safetensors_torch_load_file_wrapper @@ -120,7 +130,10 @@ def __exit__(self, exc_type, exc_value, traceback): torch.fx._symbolic_trace._wrapped_methods_to_patch = ( self.torch_fx__symbolic_trace__wrapped_methods_to_patch ) - if has_safetensors_and_transformers: + if has_safetensors_and_transformers(): + import safetensors + import transformers + safetensors.torch.load_file = self.safetensors_torch_load_file transformers.modeling_utils.safe_load_file = ( self.transformers_modeling_utils_safe_load_file