Skip to content

Commit

Permalink
Update on "Do not import transformers when importing torch._dynamo"
Browse files Browse the repository at this point in the history

Fixes #123954



[ghstack-poisoned]
  • Loading branch information
soulitzer committed Apr 22, 2024
2 parents fe887a0 + 8c7d415 commit 27d40a8
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions torch/onnx/_internal/fx/patcher.py
@@ -1,17 +1,18 @@
import copy
import functools
import io
from typing import List, Union
import functools

import torch


# TODO: Remove after https://github.com/huggingface/safetensors/pull/318
@functools.cache
@functools.lru_cache(None)
def has_safetensors_and_transformers():
try:
# safetensors is not an exporter requirement, but needed for some huggingface models
import safetensors # type: ignore[import] # noqa: F401
import transformers # type: ignore[import]
import transformers # type: ignore[import] # noqa: F401

from safetensors import torch as safetensors_torch # noqa: F401

Expand Down Expand Up @@ -66,6 +67,8 @@ def torch_load_wrapper(f, *args, **kwargs):
self.torch_load_wrapper = torch_load_wrapper

if has_safetensors_and_transformers():
import safetensors
import transformers

def safetensors_load_file_wrapper(filename, device="cpu"):
# Record path for later serialization into ONNX proto
Expand Down Expand Up @@ -114,6 +117,9 @@ def __enter__(self):
torch.fx._symbolic_trace._wrapped_methods_to_patch = desired_wrapped_methods

if has_safetensors_and_transformers():
import safetensors
import transformers

safetensors.torch.load_file = self.safetensors_torch_load_file_wrapper
transformers.modeling_utils.safe_load_file = (
self.safetensors_torch_load_file_wrapper
Expand All @@ -125,6 +131,9 @@ def __exit__(self, exc_type, exc_value, traceback):
self.torch_fx__symbolic_trace__wrapped_methods_to_patch
)
if has_safetensors_and_transformers():
import safetensors
import transformers

safetensors.torch.load_file = self.safetensors_torch_load_file
transformers.modeling_utils.safe_load_file = (
self.transformers_modeling_utils_safe_load_file
Expand Down

0 comments on commit 27d40a8

Please sign in to comment.