Fix Falcon HF warnings (run-llama#222)
* Fix Falcon HF warnings.

* Remove commented line.
rmitsch committed Jul 18, 2023
1 parent 32e6be5 commit 39a4750
Showing 1 changed file with 9 additions and 6 deletions.

spacy_llm/models/hf/falcon.py
@@ -22,7 +22,11 @@ def __init__(
         self._device: Optional[str] = None
         super().__init__(name=name, config_init=config_init, config_run=config_run)
         assert isinstance(self._tokenizer, transformers.PreTrainedTokenizerBase)
-        self._config_run["eos_token_id"] = self._tokenizer.eos_token_id
+        self._config_run["pad_token_id"] = self._tokenizer.pad_token_id
+        # Instantiate GenerationConfig object from config dict.
+        self._hf_config_run = transformers.GenerationConfig.from_pretrained(
+            self._name, **self._config_run
+        )
 
     def init_model(self) -> Any:
         self._tokenizer = transformers.AutoTokenizer.from_pretrained(self._name)
@@ -39,7 +43,8 @@ def hf_account(self) -> str:
 
     def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[override]
         return [
-            self._model(pr, **self._config_run)[0]["generated_text"] for pr in prompts
+            self._model(pr, generation_config=self._hf_config_run)[0]["generated_text"]
+            for pr in prompts
         ]
 
     @staticmethod
@@ -52,10 +57,8 @@ def compile_default_configs() -> Tuple[Dict[str, Any], Dict[str, Any]]:
             },
             {
                 **default_cfg_run,
-                "max_length": 200,
-                "do_sample": True,
-                "top_k": 10,
-                "num_return_sequences": 1,
+                # In here because the original `max_length` is deprecated and triggers a warning.
+                "max_new_tokens": 200,
             },
         )
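
The pattern this diff adopts is to collect all generation settings in a single transformers.GenerationConfig and hand that one object to the pipeline call, instead of spreading raw keyword arguments, which recent transformers releases warn about. Below is a minimal standalone sketch of that pattern, assuming a recent transformers release; the checkpoint name and the settings are illustrative, not values taken from this repository.

import transformers

name = "tiiuae/falcon-rw-1b"  # assumed small example checkpoint
tokenizer = transformers.AutoTokenizer.from_pretrained(name)

# An explicit pad token suppresses the "Setting `pad_token_id` to
# `eos_token_id`" warning; fall back to eos if the tokenizer defines no pad token.
pad_token_id = tokenizer.pad_token_id
if pad_token_id is None:
    pad_token_id = tokenizer.eos_token_id

config_run = {
    # `max_new_tokens` replaces the deprecated `max_length`, avoiding its warning.
    "max_new_tokens": 200,
    "pad_token_id": pad_token_id,
}
# Build the GenerationConfig once up front, as the commit does in __init__.
generation_config = transformers.GenerationConfig.from_pretrained(name, **config_run)

pipe = transformers.pipeline(
    "text-generation",
    model=name,
    tokenizer=tokenizer,
    trust_remote_code=True,  # needed for Falcon on older transformers releases
)
print(pipe("Falcons are", generation_config=generation_config)[0]["generated_text"])

Because the GenerationConfig is built once at construction time, the per-prompt call in __call__ stays a plain pipeline invocation, and the deprecation warnings are no longer emitted on every generation.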

