Skip to content

Commit

Permalink
fix encoder hook (huggingface#25735)
Browse files Browse the repository at this point in the history
* fix encoder hook

* style
  • Loading branch information
SunMarc authored and parambharat committed Sep 26, 2023
1 parent f5dc0cc commit d0a1a65
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/transformers/generation/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
MODEL_FOR_VISION_2_SEQ_MAPPING,
)
from ..utils import ExplicitEnum, ModelOutput, logging
from ..utils import ExplicitEnum, ModelOutput, is_accelerate_available, logging
from .beam_constraints import DisjunctiveConstraint, PhrasalConstraint
from .beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
from .configuration_utils import GenerationConfig
Expand Down Expand Up @@ -80,6 +80,9 @@

logger = logging.get_logger(__name__)

if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, add_hook_to_module


@dataclass
class GreedySearchDecoderOnlyOutput(ModelOutput):
Expand Down Expand Up @@ -631,8 +634,11 @@ def _prepare_encoder_decoder_kwargs_for_generation(
encoder = self.get_encoder()
# Compatibility with Accelerate big model inference: we need the encoder to outputs stuff on the same device
# as the inputs.
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
if hasattr(self, "hf_device_map"):
if hasattr(encoder, "_hf_hook"):
encoder._hf_hook.io_same_device = True
else:
add_hook_to_module(encoder, AlignDevicesHook(io_same_device=True))

# 2. Prepare encoder args and encoder kwargs from model kwargs.
irrelevant_prefix = ["decoder_", "cross_attn", "use_cache"]
Expand Down

0 comments on commit d0a1a65

Please sign in to comment.