Skip to content

Commit

Permalink
Support PeftModel signature inspect (huggingface#27865)
Browse files Browse the repository at this point in the history
* Support PeftModel signature inspect

* Use get_base_model() to get the base model

---------

Co-authored-by: shujunhua1 <shujunhua1@jd.com>
  • Loading branch information
dancingpipi and shujunhua1 committed Dec 11, 2023
1 parent 3547818 commit e5079b0
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,10 @@ def _move_model_to_device(self, model, device):
def _set_signature_columns_if_needed(self):
if self._signature_columns is None:
# Inspect model forward signature to keep only the arguments it accepts.
signature = inspect.signature(self.model.forward)
model_to_inspect = self.model
if is_peft_available() and isinstance(self.model, PeftModel):
model_to_inspect = self.model.get_base_model()
signature = inspect.signature(model_to_inspect.forward)
self._signature_columns = list(signature.parameters.keys())
# Labels may be named label or label_ids, the default data collator handles that.
self._signature_columns += list(set(["label", "label_ids"] + self.label_names))
Expand Down

0 comments on commit e5079b0

Please sign in to comment.