Do not import transformers when importing torch._dynamo (pytorch#124634)
Fixes pytorch#123954

Pull Request resolved: pytorch#124634
Approved by: https://github.com/thiagocrepaldi, https://github.com/Chillee
ghstack dependencies: pytorch#124343
soulitzer authored and petrex committed May 3, 2024
1 parent 9b8580b commit 31bb9d1
Showing 1 changed file with 24 additions and 11 deletions.
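
The fix replaces a module-scope import probe with a lazily evaluated, cached predicate: the try/except that used to run (and import transformers) as a side effect of importing this file now lives inside a function, so nothing heavy is imported until the first call. A minimal standalone sketch of the pattern, with hypothetical names (has_optional_dependency is not from the patch):

import functools


@functools.lru_cache(None)  # unbounded cache: the import probe runs at most once
def has_optional_dependency():
    # Hypothetical stand-in for the patch's has_safetensors_and_transformers().
    try:
        import transformers  # noqa: F401  # heavy optional dependency

        return True
    except ImportError:
        return False


# Importing this module is now cheap; the import cost is paid only at the
# first call site that actually needs the optional package:
if has_optional_dependency():
    import transformers  # guaranteed importable: the probe above succeeded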
35 changes: 24 additions & 11 deletions torch/onnx/_internal/fx/patcher.py
@@ -1,19 +1,24 @@
 import copy
+import functools
 import io
 from typing import List, Union
 
 import torch
 
 
 # TODO: Remove after https://github.com/huggingface/safetensors/pull/318
-try:
-    # safetensors is not an exporter requirement, but needed for some huggingface models
-    import safetensors  # type: ignore[import]  # noqa: F401
-    import transformers  # type: ignore[import]
-    from safetensors import torch as safetensors_torch  # noqa: F401
-    has_safetensors_and_transformers = True
-except ImportError:
-    has_safetensors_and_transformers = False
+@functools.lru_cache(None)
+def has_safetensors_and_transformers():
+    try:
+        # safetensors is not an exporter requirement, but needed for some huggingface models
+        import safetensors  # type: ignore[import]  # noqa: F401
+        import transformers  # type: ignore[import]  # noqa: F401
+
+        from safetensors import torch as safetensors_torch  # noqa: F401
+
+        return True
+    except ImportError:
+        return False


class ONNXTorchPatcher:
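
Because the new predicate is wrapped in functools.lru_cache(None), the try/except import probe executes at most once per process; every later call returns the memoized boolean. A quick illustration of that caching behavior (hypothetical names, not part of the patch):

import functools

calls = 0  # module-level counter, for illustration only


@functools.lru_cache(None)
def probe():
    global calls
    calls += 1
    return True


probe(); probe(); probe()
assert calls == 1  # the body ran once; the two later calls hit the cache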
@@ -61,7 +66,9 @@ def torch_load_wrapper(f, *args, **kwargs):
         # Wrapper or modified version of torch functions.
         self.torch_load_wrapper = torch_load_wrapper
 
-        if has_safetensors_and_transformers:
+        if has_safetensors_and_transformers():
+            import safetensors
+            import transformers
 
             def safetensors_load_file_wrapper(filename, device="cpu"):
                 # Record path for later serialization into ONNX proto
@@ -109,7 +116,10 @@ def __enter__(self):
         desired_wrapped_methods.append((torch.Tensor, "__getitem__"))
         torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods
 
-        if has_safetensors_and_transformers:
+        if has_safetensors_and_transformers():
+            import safetensors
+            import transformers
+
             safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
             transformers.modeling_utils.safe_load_file = (
                 self.safetensors_torch_load_file_wrapper
@@ -120,7 +130,10 @@ def __exit__(self, exc_type, exc_value, traceback):
         torch.fx._symbolic_trace._wrapped_methods_to_patch = (
             self.torch_fx__symbolic_trace__wrapped_methods_to_patch
         )
-        if has_safetensors_and_transformers:
+        if has_safetensors_and_transformers():
+            import safetensors
+            import transformers
+
             safetensors.torch.load_file = self.safetensors_torch_load_file
             transformers.modeling_utils.safe_load_file = (
                 self.transformers_modeling_utils_safe_load_file
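
Because safetensors and transformers are no longer bound at module scope, each patched method re-imports them locally before touching their attributes. The repeated import statements are cheap: after a module has been imported once, Python resolves subsequent imports from the sys.modules cache without re-executing the module body. A small demonstration using json as a stand-in:

import sys

import json  # first import executes the module body once

assert "json" in sys.modules

import json as json_again  # later imports are just a sys.modules lookup

assert json_again is sys.modules["json"]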