diff --git a/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py b/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py
index ebfa2d663..950172f81 100644
--- a/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py
+++ b/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py
@@ -3,55 +3,61 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-import fire
 import json
 import os
+
 import sys
+import fire
+
 import torch
-from transformers import AutoTokenizer
+from accelerate.utils import is_xpu_available
 from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
-from accelerate.utils import is_xpu_available
+from transformers import AutoTokenizer
+
 
 def main(
     model_name,
-    peft_model: str=None,
-    quantization: str = None, # Options: 4bit, 8bit
-    max_new_tokens =256, #The maximum numbers of tokens to generate
-    min_new_tokens:int=0, #The minimum numbers of tokens to generate
-    prompt_file: str=None,
-    seed: int=42, #seed value for reproducibility
-    safety_score_threshold: float=0.5,
-    do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
-    use_cache: bool=True, #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
-    top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
-    temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
-    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
-    repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
-    length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
-    enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
-    enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
-    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    peft_model: str = None,
+    quantization: str = None,  # Options: 4bit, 8bit
+    max_new_tokens=256,  # The maximum numbers of tokens to generate
+    min_new_tokens: int = 0,  # The minimum numbers of tokens to generate
+    prompt_file: str = None,
+    seed: int = 42,  # seed value for reproducibility
+    safety_score_threshold: float = 0.5,
+    do_sample: bool = True,  # Whether or not to use sampling ; use greedy decoding otherwise.
+    use_cache: bool = True,  # [optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
+    top_p: float = 1.0,  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+    temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
+    top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
+    length_penalty: int = 1,  # [optional] Exponential penalty to the length that is used with beam-based generation.
+    enable_azure_content_safety: bool = False,  # Enable safety check with Azure content safety api
+    enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
+    enable_saleforce_content_safety: bool = True,  # Enable safety check woth Saleforce safety flan t5
+    use_fast_kernels: bool = False,  # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     enable_llamaguard_content_safety: bool = False,
-    **kwargs
+    **kwargs,
 ):
+
     if prompt_file is not None:
         assert os.path.exists(
             prompt_file
         ), f"Provided Prompt file does not exist {prompt_file}"
-        dialogs= read_dialogs_from_file(prompt_file)
+        dialogs = read_dialogs_from_file(prompt_file)
     elif not sys.stdin.isatty():
         dialogs = "\n".join(sys.stdin.readlines())
         try:
             dialogs = json.loads(dialogs)
         except:
-            print("Could not parse json from stdin. Please provide a json file with the user prompts. Exiting.")
+            print(
+                "Could not parse json from stdin. Please provide a json file with the user prompts. Exiting."
+            )
             sys.exit(1)
     else:
         print("No user prompt provided. Exiting.")
@@ -59,7 +65,7 @@ def main(
 
     print(f"User dialogs:\n{dialogs}")
     print("\n==================================\n")
-    
+
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(seed)
@@ -77,13 +83,16 @@ def main(
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):
-            safety_checker = get_safety_checker(enable_azure_content_safety,
-                                        enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
-                                        enable_llamaguard_content_safety,
-                                        )
+            safety_checker = get_safety_checker(
+                enable_azure_content_safety,
+                enable_sensitive_topics,
+                enable_saleforce_content_safety,
+                enable_llamaguard_content_safety,
+            )
             # Safety check of the user prompt
-            safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]
+            safety_results = [
+                check(dialogs[idx][0]["content"]) for check in safety_checker
+            ]
             are_safe = all([r[1] for r in safety_results])
             if are_safe:
                 print(f"User prompt deemed safe.")
@@ -97,13 +106,15 @@ def main(
                         print(report)
                 print("Skipping the inferece as the prompt is not safe.")
                 sys.exit(1)  # Exit the program with an error status
-            tokens= torch.tensor(chat).long()
-            tokens= tokens.unsqueeze(0)
+            tokens = torch.tensor(chat).long()
+            tokens = tokens.unsqueeze(0)
             attention_mask = torch.ones_like(tokens)
             if is_xpu_available():
-                tokens= tokens.to("xpu:0")
+                tokens = tokens.to("xpu")
+                attention_mask = attention_mask.to("xpu")
             else:
-                tokens= tokens.to("cuda:0")
+                tokens = tokens.to("cuda")
+                attention_mask = attention_mask.to("cuda")
             outputs = model.generate(
                 input_ids=tokens,
                 attention_mask=attention_mask,
@@ -115,7 +126,7 @@ def main(
                 top_k=top_k,
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
-                **kwargs
+                **kwargs,
             )
 
             output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -136,6 +147,5 @@ def main(
                 print(report)
 
 
-
 if __name__ == "__main__":
     fire.Fire(main)
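
Usage sketch (not part of the patch): beyond import ordering and black-style layout, the only behavioral changes above are that the attention mask now follows the input tokens onto the target device and the device strings drop the explicit ":0" index. The fire-based CLI is unchanged, so the script can still be driven with --prompt_file or via stdin. The checkpoint path and file name below are placeholders, and the dialog format is inferred from read_dialogs_from_file and dialogs[idx][0]["content"] rather than stated in the diff.

    # chats.json: a list of dialogs, each dialog a list of messages with a "content" field, e.g.
    # [[{"role": "user", "content": "Give me a recipe for banana bread."}]]
    python chat_completion.py --model_name <path/to/llama/checkpoint> --prompt_file chats.json --quantization 8bit
    # or pipe the same JSON over stdin:
    cat chats.json | python chat_completion.py --model_name <path/to/llama/checkpoint>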