fixed: #41

shibing624 committed Jun 13, 2023
1 parent bbceb9d commit ee420bc
Showing 3 changed files with 153 additions and 48 deletions.
56 changes: 43 additions & 13 deletions textgen/bloom/bloom_model.py
@@ -459,16 +459,39 @@ def train_model(
return global_step, training_loss

@torch.inference_mode()
def predict(self, sentences: List[str], keep_prompt: bool = False, max_length: int = None,
add_system_prompt=False, **kwargs):
def predict(
self,
sentences: List[str],
keep_prompt: bool = False,
add_system_prompt=False,
max_length: int = 256,
temperature: float = 0.95,
top_p: float = 0.9,
top_k: int = 40,
do_sample: bool = True,
repetition_penalty: float = 1.3,
length_penalty: float = 2.0,
num_beams: int = 1,
num_return_sequences: int = 1,
**kwargs
):
"""
Performs predictions on a list of text.
Args:
sentences: A python list of text (str) to be sent to the model for prediction. Note that the prefix should be prepended to the text.
keep_prompt: Whether to keep the prompt in the generated text.
add_system_prompt: Whether to add the system prompt to the prompt text.
max_length: The maximum length of the generated text.
add_system_prompt: Whether to add the system prompt to the prompt text. default: False
temperature: The value used to modulate the next token probabilities.
top_p: The cumulative probability threshold for nucleus sampling; only the highest-probability vocabulary tokens whose probabilities sum to top_p are kept.
top_k: The number of highest probability vocabulary tokens to keep for top-k filtering.
do_sample: Whether or not to use sampling; use greedy decoding otherwise.
repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty.
length_penalty: The parameter that penalizes longer sequences.
num_beams: The number of beams to use for beam search. 1 means no beam search.
num_return_sequences: The number of independently computed returned sequences for each element in the batch.
**kwargs: Additional arguments for generating sequences.
Returns:
preds: A python list of the generated sequences.
@@ -495,16 +518,16 @@ def predict(self, sentences: List[str], keep_prompt: bool = False, max_length: i
inputs = self.tokenizer(batch, padding=True, return_tensors='pt').to(self.device)
generation_config = GenerationConfig(
max_new_tokens=max_length if max_length else self.args.max_length,
temperature=self.args.temperature,
top_p=self.args.top_p,
top_k=self.args.top_k,
do_sample=self.args.do_sample,
repetition_penalty=self.args.repetition_penalty,
length_penalty=self.args.length_penalty,
num_beams=self.args.num_beams,
temperature=temperature if temperature is not None else self.args.temperature,
top_p=top_p if top_p else self.args.top_p,
top_k=top_k if top_k else self.args.top_k,
do_sample=do_sample if do_sample is not None else self.args.do_sample,
repetition_penalty=repetition_penalty if repetition_penalty else self.args.repetition_penalty,
length_penalty=length_penalty if length_penalty else self.args.length_penalty,
num_beams=num_beams if num_beams else self.args.num_beams,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
num_return_sequences=self.args.num_return_sequences,
num_return_sequences=num_return_sequences if num_return_sequences else self.args.num_return_sequences,
return_dict_in_generate=True,
output_scores=True,
**kwargs,
@@ -523,8 +546,15 @@ def predict(self, sentences: List[str], keep_prompt: bool = False, max_length: i
return all_outputs

@torch.inference_mode()
def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: bool = False,
max_length: int = 2048, add_system_prompt=True, **kwargs):
def chat(
self,
query: str,
history: List[Tuple[str, str]] = None,
keep_prompt: bool = False,
add_system_prompt=True,
max_length: int = 2048,
**kwargs
):
"""Chat with the model."""
if history is None:
history = []
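
After this change, BloomModel.predict() takes its decoding parameters per call instead of reading them only from self.args. A minimal usage sketch, assuming the textgen package exports BloomModel and that its constructor takes (model_type, model_name) as elsewhere in the repo; the import path and model name are assumptions, not part of this diff:

# Hypothetical usage sketch -- import path, constructor args and model name are assumptions.
from textgen import BloomModel

model = BloomModel("bloom", "bigscience/bloomz-560m")
preds = model.predict(
    ["Tell me a joke.", "Translate to French: hello"],
    max_length=128,       # caps max_new_tokens for this call
    temperature=0.7,      # overrides self.args.temperature via the "is not None" check
    top_p=0.9,
    do_sample=True,
    num_return_sequences=1,
)
print(preds)
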
89 changes: 67 additions & 22 deletions textgen/chatglm/chatglm_model.py
@@ -462,16 +462,39 @@ def process_response(self, response):
return response

@torch.inference_mode()
def predict(self, sentences: List[str], keep_prompt: bool = False,
max_length: int = None, add_system_prompt=False, **kwargs):
def predict(
self,
sentences: List[str],
keep_prompt: bool = False,
add_system_prompt=False,
max_length: int = 256,
temperature: float = 0.95,
top_p: float = 0.7,
top_k: int = 40,
do_sample: bool = True,
repetition_penalty: float = 1.0,
length_penalty: float = 2.0,
num_beams: int = 1,
num_return_sequences: int = 1,
**kwargs
) -> List[str]:
"""
Performs predictions on a list of text.
Args:
sentences: A python list of text (str) to be sent to the model for prediction.
keep_prompt: Whether to keep the prompt in the generated text.
max_length: The maximum length of the generated text.
add_system_prompt: Whether to add the system prompt to the prompt text.
max_length: The maximum length of the generated text.
temperature: The value used to modulate the next token probabilities.
top_p: The cumulative probability threshold for nucleus sampling; only the highest-probability vocabulary tokens whose probabilities sum to top_p are kept.
top_k: The number of highest probability vocabulary tokens to keep for top-k filtering.
do_sample: Whether or not to use sampling; use greedy decoding otherwise.
repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty.
length_penalty: The parameter that penalizes longer sequences.
num_beams: The number of beams to use for beam search. 1 means no beam search.
num_return_sequences: The number of independently computed returned sequences for each element in the batch.
**kwargs: Additional arguments for generating sequences.
Returns:
preds: A python list of the generated sequences.
@@ -498,16 +521,16 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
inputs = self.tokenizer(batch, padding=True, return_tensors='pt').to(self.device)
gen_kwargs = {
"max_new_tokens": max_length if max_length else self.args.max_length,
"temperature": self.args.temperature,
"top_p": self.args.top_p,
"top_k": self.args.top_k,
"do_sample": self.args.do_sample,
"repetition_penalty": self.args.repetition_penalty,
"length_penalty": self.args.length_penalty,
"num_beams": self.args.num_beams,
"temperature": temperature if temperature is not None else self.args.temperature,
"top_p": top_p if top_p else self.args.top_p,
"top_k": top_k if top_k else self.args.top_k,
"do_sample": do_sample if do_sample is not None else self.args.do_sample,
"repetition_penalty": repetition_penalty if repetition_penalty else self.args.repetition_penalty,
"length_penalty": length_penalty if length_penalty else self.args.length_penalty,
"num_beams": num_beams if num_beams else self.args.num_beams,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
"num_return_sequences": self.args.num_return_sequences,
"num_return_sequences": num_return_sequences if num_return_sequences else self.args.num_return_sequences,
**kwargs
}
outputs = self.model.generate(**inputs, **gen_kwargs)
@@ -525,8 +548,15 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
return all_outputs

@torch.inference_mode()
def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: bool = False,
max_length: int = 2048, add_system_prompt=True, **kwargs):
def chat(
self,
query: str,
history: List[Tuple[str, str]] = None,
keep_prompt: bool = False,
add_system_prompt=True,
max_length: int = 2048,
**kwargs
):
"""
Chat with the model
:param query:
@@ -553,22 +583,37 @@ def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: b
return response, history

@torch.inference_mode()
def stream_chat(self, query: str, history: List[Tuple[str, str]] = None,
max_length: int = 2048, add_system_prompt=True, **kwargs):
def stream_chat(
self,
query: str,
history: List[Tuple[str, str]] = None,
add_system_prompt=True,
max_length: int = 2048,
temperature: float = 0.95,
top_p: float = 0.7,
top_k: int = 40,
do_sample: bool = True,
repetition_penalty: float = 1.0,
length_penalty: float = 2.0,
num_beams: int = 1,
num_return_sequences: int = 1,
**kwargs
):
"""Chat with the model in a streaming fashion"""
if history is None:
history = []
gen_kwargs = {
"max_new_tokens": max_length if max_length else self.args.max_length,
"temperature": self.args.temperature,
"top_p": self.args.top_p,
"do_sample": self.args.do_sample,
"repetition_penalty": self.args.repetition_penalty,
"length_penalty": self.args.length_penalty,
"num_beams": self.args.num_beams,
"temperature": temperature if temperature is not None else self.args.temperature,
"top_p": top_p if top_p else self.args.top_p,
"top_k": top_k if top_k else self.args.top_k,
"do_sample": do_sample if do_sample is not None else self.args.do_sample,
"repetition_penalty": repetition_penalty if repetition_penalty else self.args.repetition_penalty,
"length_penalty": length_penalty if length_penalty else self.args.length_penalty,
"num_beams": num_beams if num_beams else self.args.num_beams,
"eos_token_id": self.tokenizer.eos_token_id,
"pad_token_id": self.tokenizer.pad_token_id,
"num_return_sequences": self.args.num_return_sequences,
"num_return_sequences": num_return_sequences if num_return_sequences else self.args.num_return_sequences,
**kwargs
}
if not history:
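
For ChatGLM, predict() and stream_chat() now accept the same per-call decoding arguments. A hedged sketch follows; the import path, constructor form and model name are assumptions, and the exact shape of what stream_chat yields is not visible in this diff:

# Hypothetical usage sketch -- constructor args and model name are assumptions.
from textgen import ChatGlmModel

model = ChatGlmModel("chatglm", "THUDM/chatglm-6b")
response, history = model.chat("你好", history=[])   # chat() returns (response, history), as shown above
stream = model.stream_chat(
    "请介绍一下北京",
    history=history,
    max_length=512,
    temperature=0.95,  # falls back to self.args.temperature only when None
    top_p=0.7,
)
for item in stream:
    print(item)  # yielded item shape (text chunk vs (response, history)) is not shown in this diff
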
56 changes: 43 additions & 13 deletions textgen/llama/llama_model.py
@@ -468,16 +468,39 @@ def train_model(
return global_step, training_loss

@torch.inference_mode()
def predict(self, sentences: List[str], keep_prompt: bool = False,
max_length: int = None, add_system_prompt=False, **kwargs):
def predict(
self,
sentences: List[str],
keep_prompt: bool = False,
add_system_prompt=False,
max_length: int = 256,
temperature: float = 0.95,
top_p: float = 0.9,
top_k: int = 40,
do_sample: bool = True,
repetition_penalty: float = 1.3,
length_penalty: float = 2.0,
num_beams: int = 1,
num_return_sequences: int = 1,
**kwargs
) -> List[str]:
"""
Performs predictions on a list of text.
Args:
sentences: A python list of text (str) to be sent to the model for prediction. Note that the prefix should be prepended to the text.
keep_prompt: Whether to keep the prompt in the generated text.
max_length: The maximum length of the generated text.
add_system_prompt: Whether to add the system prompt to the prompt text.
max_length: The maximum length of the generated text.
temperature: The value used to modulate the next token probabilities.
top_p: The cumulative probability threshold for nucleus sampling; only the highest-probability vocabulary tokens whose probabilities sum to top_p are kept.
top_k: The number of highest probability vocabulary tokens to keep for top-k filtering.
do_sample: Whether or not to use sampling; use greedy decoding otherwise.
repetition_penalty: The parameter for repetition penalty. 1.0 means no penalty.
length_penalty: The parameter that penalizes longer sequences.
num_beams: The number of beams to use for beam search. 1 means no beam search.
num_return_sequences: The number of independently computed returned sequences for each element in the batch.
**kwargs: Additional arguments for generating sequences.
Returns:
preds: A python list of the generated sequences.
@@ -504,16 +527,16 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
inputs = self.tokenizer(batch, padding=True, return_tensors='pt')
generation_config = GenerationConfig(
max_new_tokens=max_length if max_length else self.args.max_length,
temperature=self.args.temperature,
top_p=self.args.top_p,
top_k=self.args.top_k,
do_sample=self.args.do_sample,
repetition_penalty=self.args.repetition_penalty,
length_penalty=self.args.length_penalty,
num_beams=self.args.num_beams,
temperature=temperature if temperature is not None else self.args.temperature,
top_p=top_p if top_p else self.args.top_p,
top_k=top_k if top_k else self.args.top_k,
do_sample=do_sample if do_sample is not None else self.args.do_sample,
repetition_penalty=repetition_penalty if repetition_penalty else self.args.repetition_penalty,
length_penalty=length_penalty if length_penalty else self.args.length_penalty,
num_beams=num_beams if num_beams else self.args.num_beams,
eos_token_id=self.tokenizer.eos_token_id,
pad_token_id=self.tokenizer.pad_token_id,
num_return_sequences=self.args.num_return_sequences,
num_return_sequences=num_return_sequences if num_return_sequences else self.args.num_return_sequences,
return_dict_in_generate=True,
output_scores=True,
**kwargs,
@@ -536,8 +559,15 @@ def predict(self, sentences: List[str], keep_prompt: bool = False,
return all_outputs

@torch.inference_mode()
def chat(self, query: str, history: List[Tuple[str, str]] = None, keep_prompt: bool = False,
max_length: int = 2048, add_system_prompt=True, **kwargs):
def chat(
self,
query: str,
history: List[Tuple[str, str]] = None,
keep_prompt: bool = False,
add_system_prompt=True,
max_length: int = 2048,
**kwargs
):
"""
Chat with the model
:param query:
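
LlamaModel.predict() gets the same treatment. A sketch of forcing deterministic beam search for one call; the import path, constructor args and model path are assumptions:

# Hypothetical usage sketch -- constructor args and model path are assumptions.
from textgen import LlamaModel

model = LlamaModel("llama", "path/to/llama-hf-model")
# do_sample=False reaches GenerationConfig because the override uses "is not None",
# so an explicit False is not silently replaced by self.args.do_sample.
preds = model.predict(
    ["Write a haiku about autumn."],
    do_sample=False,
    num_beams=4,
    max_length=64,
)
print(preds[0])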
