From ee420bc7abaecf511488df786292ddba42ae1d29 Mon Sep 17 00:00:00 2001
From: shibing624
Date: Tue, 13 Jun 2023 17:36:26 +0800
Subject: [PATCH] fixed: https://github.com/shibing624/textgen/issues/41

---
 textgen/bloom/bloom_model.py     | 56 +++++++++++++++-----
 textgen/chatglm/chatglm_model.py | 89 ++++++++++++++++++++++--------
 textgen/llama/llama_model.py     | 56 +++++++++++++++-----
 3 files changed, 153 insertions(+), 48 deletions(-)

diff --git a/textgen/bloom/bloom_model.py b/textgen/bloom/bloom_model.py
index 6d5b1f2..d981977 100644
--- a/textgen/bloom/bloom_model.py
+++ b/textgen/bloom/bloom_model.py
@@ -459,16 +459,39 @@ def train_model(
         return global_step, training_loss
 
     @torch.inference_mode()
-    def predict(self, sentences: List[str], keep_prompt: bool = False, max_length: int = None,
-                add_system_prompt=False, **kwargs):
+    def predict(
+            self,
+            sentences: List[str],
+            keep_prompt: bool = False,
+            add_system_prompt=False,
+            max_length: int = 256,
+            temperature: float = 0.95,
+            top_p: float = 0.9,
+            top_k: int = 40,
+            do_sample: bool = True,
+            repetition_penalty: float = 1.3,
+            length_penalty: float = 2.0,
+            num_beams: int = 1,
+            num_return_sequences: int = 1,
+            **kwargs
+    ):
         """
         Performs predictions on a list of text.
         Args:
             sentences: A python list of text (str) to be sent to the model for prediction.
                 Note that the prefix should be prepended to the text.
             keep_prompt: Whether to keep the prompt in the generated text.
+            add_system_prompt: Whether to add the system prompt to the prompt text.
             max_length: The maximum length of the generated text.
-            add_system_prompt: Whether to add the system prompt to the prompt text. default: False
+            temperature: The value used to modulate the next token probabilities.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k filtering.
+            do_sample: Whether or not to use sampling; use greedy decoding otherwise.
+            repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty.
+            length_penalty: The parameter that penalizes longer sequences.
+            num_beams: The number of beams to use for beam search. 1 means no beam search.
+            num_return_sequences: The number of independently computed returned sequences for each element in the batch.
+            **kwargs: Additional arguments for generating sequences.
 
         Returns:
             preds: A python list of the generated sequences.
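For orientation, a minimal usage sketch of the per-call generation arguments this patch adds to predict(). The pre-initialized model object and the prompt below are illustrative assumptions, not part of the patch; the keyword names and defaults follow the new signature in the hunk above.

```python
# Sketch only: assumes `model` is an already-initialized instance of the BLOOM
# wrapper class defined in textgen/bloom/bloom_model.py.
preds = model.predict(
    ["Give three tips for staying healthy."],  # illustrative prompt
    keep_prompt=False,
    max_length=128,           # cap on newly generated tokens for this call
    temperature=0.7,          # overrides the value stored in self.args for this call
    top_p=0.9,
    top_k=40,
    do_sample=True,           # False switches to greedy decoding
    repetition_penalty=1.1,
    num_beams=1,              # 1 disables beam search
    num_return_sequences=1,
)
print(preds[0])
```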
@@ -495,16 +518,16 @@ def predict(self, sentences: List[str], keep_prompt: bool = False, max_length: i
             inputs = self.tokenizer(batch, padding=True, return_tensors='pt').to(self.device)
             generation_config = GenerationConfig(
                 max_new_tokens=max_length if max_length else self.args.max_length,
-                temperature=self.args.temperature,
-                top_p=self.args.top_p,
-                top_k=self.args.top_k,
-                do_sample=self.args.do_sample,
-                repetition_penalty=self.args.repetition_penalty,
-                length_penalty=self.args.length_penalty,
-                num_beams=self.args.num_beams,
+                temperature=temperature if temperature is not None else self.args.temperature,
+                top_p=top_p if top_p else self.args.top_p,
+                top_k=top_k if top_k else self.args.top_k,
+                do_sample=do_sample if do_sample is not None else self.args.do_sample,
+                repetition_penalty=repetition_penalty if repetition_penalty else self.args.repetition_penalty,
+                length_penalty=length_penalty if length_penalty else self.args.length_penalty,
+                num_beams=num_beams if num_beams else self.args.num_beams,
                 eos_token_id=self.tokenizer.eos_token_id,
                 pad_token_id=self.tokenizer.pad_token_id,
-                num_return_sequences=self.args.num_return_sequences,
+                num_return_sequences=num_return_sequences if num_return_sequences else self.args.num_return_sequences,
                 return_dict_in_generate=True,
                 output_scores=True,
                 **kwargs,
@@ -523,8 +546,15 @@ def predict(self, sentences: List[str], keep_prompt: bool = False, max_length: i
         return all_outputs
 
     @torch.inference_mode()
-    def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: bool = False,
-             max_length: int = 2048, add_system_prompt=True, **kwargs):
+    def chat(
+            self,
+            query: str,
+            history: List[Tuple[str, str]] = None,
+            keep_prompt: bool = False,
+            add_system_prompt=True,
+            max_length: int = 2048,
+            **kwargs
+    ):
         """Chat with the model."""
         if history is None:
             history = []
diff --git a/textgen/chatglm/chatglm_model.py b/textgen/chatglm/chatglm_model.py
index a8413bf..134d553 100644
--- a/textgen/chatglm/chatglm_model.py
+++ b/textgen/chatglm/chatglm_model.py
@@ -462,16 +462,39 @@ def process_response(self, response):
         return response
 
     @torch.inference_mode()
-    def predict(self, sentences: List[str], keep_prompt: bool = False,
-                max_length: int = None, add_system_prompt=False, **kwargs):
+    def predict(
+            self,
+            sentences: List[str],
+            keep_prompt: bool = False,
+            add_system_prompt=False,
+            max_length: int = 256,
+            temperature: float = 0.95,
+            top_p: float = 0.7,
+            top_k: int = 40,
+            do_sample: bool = True,
+            repetition_penalty: float = 1.0,
+            length_penalty: float = 2.0,
+            num_beams: int = 1,
+            num_return_sequences: int = 1,
+            **kwargs
+    ) -> List[str]:
         """
         Performs predictions on a list of text.
 
         Args:
             sentences: A python list of text (str) to be sent to the model for prediction.
             keep_prompt: Whether to keep the prompt in the generated text.
-            max_length: The maximum length of the generated text.
             add_system_prompt: Whether to add the system prompt to the prompt text.
+            max_length: The maximum length of the generated text.
+            temperature: The value used to modulate the next token probabilities.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k filtering.
+            do_sample: Whether or not to use sampling; use greedy decoding otherwise.
+            repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty.
+            length_penalty: The parameter that penalizes longer sequences.
+            num_beams: The number of beams to use for beam search. 1 means no beam search.
+            num_return_sequences: The number of independently computed returned sequences for each element in the batch.
+            **kwargs: Additional arguments for generating sequences.
 
         Returns:
             preds: A python list of the generated sequences.
@@ -498,16 +521,16 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
             inputs = self.tokenizer(batch, padding=True, return_tensors='pt').to(self.device)
             gen_kwargs = {
                 "max_new_tokens": max_length if max_length else self.args.max_length,
-                "temperature": self.args.temperature,
-                "top_p": self.args.top_p,
-                "top_k": self.args.top_k,
-                "do_sample": self.args.do_sample,
-                "repetition_penalty": self.args.repetition_penalty,
-                "length_penalty": self.args.length_penalty,
-                "num_beams": self.args.num_beams,
+                "temperature": temperature if temperature is not None else self.args.temperature,
+                "top_p": top_p if top_p else self.args.top_p,
+                "top_k": top_k if top_k else self.args.top_k,
+                "do_sample": do_sample if do_sample is not None else self.args.do_sample,
+                "repetition_penalty": repetition_penalty if repetition_penalty else self.args.repetition_penalty,
+                "length_penalty": length_penalty if length_penalty else self.args.length_penalty,
+                "num_beams": num_beams if num_beams else self.args.num_beams,
                 "eos_token_id": self.tokenizer.eos_token_id,
                 "pad_token_id": self.tokenizer.pad_token_id,
-                "num_return_sequences": self.args.num_return_sequences,
+                "num_return_sequences": num_return_sequences if num_return_sequences else self.args.num_return_sequences,
                 **kwargs
             }
             outputs = self.model.generate(**inputs, **gen_kwargs)
@@ -525,8 +548,15 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
         return all_outputs
 
     @torch.inference_mode()
-    def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: bool = False,
-             max_length: int = 2048, add_system_prompt=True, **kwargs):
+    def chat(
+            self,
+            query: str,
+            history: List[Tuple[str, str]] = None,
+            keep_prompt: bool = False,
+            add_system_prompt=True,
+            max_length: int = 2048,
+            **kwargs
+    ):
         """
         Chat with the model
         :param query:
@@ -553,22 +583,37 @@ def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: b
         return response, history
 
     @torch.inference_mode()
-    def stream_chat(self, query: str, history: List[Tuple[str, str]] = None,
-                    max_length: int = 2048, add_system_prompt=True, **kwargs):
+    def stream_chat(
+            self,
+            query: str,
+            history: List[Tuple[str, str]] = None,
+            add_system_prompt=True,
+            max_length: int = 2048,
+            temperature: float = 0.95,
+            top_p: float = 0.7,
+            top_k: int = 40,
+            do_sample: bool = True,
+            repetition_penalty: float = 1.0,
+            length_penalty: float = 2.0,
+            num_beams: int = 1,
+            num_return_sequences: int = 1,
+            **kwargs
+    ):
         """Chat with the model in a streaming fashion"""
         if history is None:
             history = []
         gen_kwargs = {
             "max_new_tokens": max_length if max_length else self.args.max_length,
-            "temperature": self.args.temperature,
-            "top_p": self.args.top_p,
-            "do_sample": self.args.do_sample,
-            "repetition_penalty": self.args.repetition_penalty,
-            "length_penalty": self.args.length_penalty,
-            "num_beams": self.args.num_beams,
+            "temperature": temperature if temperature is not None else self.args.temperature,
+            "top_p": top_p if top_p else self.args.top_p,
+            "top_k": top_k if top_k else self.args.top_k,
+            "do_sample": do_sample if do_sample is not None else self.args.do_sample,
+            "repetition_penalty": repetition_penalty if repetition_penalty else self.args.repetition_penalty,
+            "length_penalty": length_penalty if length_penalty else self.args.length_penalty,
+            "num_beams": num_beams if num_beams else self.args.num_beams,
             "eos_token_id": self.tokenizer.eos_token_id,
             "pad_token_id": self.tokenizer.pad_token_id,
-            "num_return_sequences": self.args.num_return_sequences,
+            "num_return_sequences": num_return_sequences if num_return_sequences else self.args.num_return_sequences,
             **kwargs
         }
         if not history:
diff --git a/textgen/llama/llama_model.py b/textgen/llama/llama_model.py
index c791862..af7728c 100644
--- a/textgen/llama/llama_model.py
+++ b/textgen/llama/llama_model.py
@@ -468,16 +468,39 @@ def train_model(
         return global_step, training_loss
 
     @torch.inference_mode()
-    def predict(self, sentences: List[str], keep_prompt: bool = False,
-                max_length: int = None, add_system_prompt=False, **kwargs):
+    def predict(
+            self,
+            sentences: List[str],
+            keep_prompt: bool = False,
+            add_system_prompt=False,
+            max_length: int = 256,
+            temperature: float = 0.95,
+            top_p: float = 0.9,
+            top_k: int = 40,
+            do_sample: bool = True,
+            repetition_penalty: float = 1.3,
+            length_penalty: float = 2.0,
+            num_beams: int = 1,
+            num_return_sequences: int = 1,
+            **kwargs
+    ) -> List[str]:
         """
         Performs predictions on a list of text.
         Args:
             sentences: A python list of text (str) to be sent to the model for prediction.
                 Note that the prefix should be prepended to the text.
             keep_prompt: Whether to keep the prompt in the generated text.
-            max_length: The maximum length of the generated text.
             add_system_prompt: Whether to add the system prompt to the prompt text.
+            max_length: The maximum length of the generated text.
+            temperature: The value used to modulate the next token probabilities.
+            top_p: The cumulative probability of the highest-probability vocabulary tokens to keep for nucleus sampling.
+            top_k: The number of highest probability vocabulary tokens to keep for top-k filtering.
+            do_sample: Whether or not to use sampling; use greedy decoding otherwise.
+            repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty.
+            length_penalty: The parameter that penalizes longer sequences.
+            num_beams: The number of beams to use for beam search. 1 means no beam search.
+            num_return_sequences: The number of independently computed returned sequences for each element in the batch.
+            **kwargs: Additional arguments for generating sequences.
 
         Returns:
             preds: A python list of the generated sequences.
@@ -504,16 +527,16 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
             inputs = self.tokenizer(batch, padding=True, return_tensors='pt')
             generation_config = GenerationConfig(
                 max_new_tokens=max_length if max_length else self.args.max_length,
-                temperature=self.args.temperature,
-                top_p=self.args.top_p,
-                top_k=self.args.top_k,
-                do_sample=self.args.do_sample,
-                repetition_penalty=self.args.repetition_penalty,
-                length_penalty=self.args.length_penalty,
-                num_beams=self.args.num_beams,
+                temperature=temperature if temperature is not None else self.args.temperature,
+                top_p=top_p if top_p else self.args.top_p,
+                top_k=top_k if top_k else self.args.top_k,
+                do_sample=do_sample if do_sample is not None else self.args.do_sample,
+                repetition_penalty=repetition_penalty if repetition_penalty else self.args.repetition_penalty,
+                length_penalty=length_penalty if length_penalty else self.args.length_penalty,
+                num_beams=num_beams if num_beams else self.args.num_beams,
                 eos_token_id=self.tokenizer.eos_token_id,
                 pad_token_id=self.tokenizer.pad_token_id,
-                num_return_sequences=self.args.num_return_sequences,
+                num_return_sequences=num_return_sequences if num_return_sequences else self.args.num_return_sequences,
                 return_dict_in_generate=True,
                 output_scores=True,
                 **kwargs,
@@ -536,8 +559,15 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
         return all_outputs
 
     @torch.inference_mode()
-    def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: bool = False,
-             max_length: int = 2048, add_system_prompt=True, **kwargs):
+    def chat(
+            self,
+            query: str,
+            history: List[Tuple[str, str]] = None,
+            keep_prompt: bool = False,
+            add_system_prompt=True,
+            max_length: int = 2048,
+            **kwargs
+    ):
         """
         Chat with the model
         :param query: