From e5079b0b2abcef11ecbdae60ba4a6636c57b725d Mon Sep 17 00:00:00 2001 From: dancingpipi Date: Tue, 12 Dec 2023 03:30:11 +0800 Subject: [PATCH] Support PeftModel signature inspect (#27865) * Support PeftModel signature inspect * Use get_base_model() to get the base model --------- Co-authored-by: shujunhua1 --- src/transformers/trainer.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 742bb3392986de..d6ccc4334dd46d 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -695,7 +695,10 @@ def _move_model_to_device(self, model, device): def _set_signature_columns_if_needed(self): if self._signature_columns is None: # Inspect model forward signature to keep only the arguments it accepts. - signature = inspect.signature(self.model.forward) + model_to_inspect = self.model + if is_peft_available() and isinstance(self.model, PeftModel): + model_to_inspect = self.model.get_base_model() + signature = inspect.signature(model_to_inspect.forward) self._signature_columns = list(signature.parameters.keys()) # Labels may be named label or label_ids, the default data collator handles that. self._signature_columns += list(set(["label", "label_ids"] + self.label_names))