diff --git a/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py b/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py
index ebfa2d663..950172f81 100644
--- a/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py
+++ b/recipes/quickstart/inference/local_inference/chat_completion/chat_completion.py
@@ -3,55 +3,61 @@
 # from accelerate import init_empty_weights, load_checkpoint_and_dispatch
-import fire
 import json
 import os
+
 import sys
+import fire
+
 import torch
-from transformers import AutoTokenizer
+from accelerate.utils import is_xpu_available
 from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
-from accelerate.utils import is_xpu_available
+from transformers import AutoTokenizer
+
 
 def main(
     model_name,
-    peft_model: str=None,
-    quantization: str = None, # Options: 4bit, 8bit
-    max_new_tokens =256, #The maximum numbers of tokens to generate
-    min_new_tokens:int=0, #The minimum numbers of tokens to generate
-    prompt_file: str=None,
-    seed: int=42, #seed value for reproducibility
-    safety_score_threshold: float=0.5,
-    do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
-    use_cache: bool=True, #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
-    top_p: float=1.0, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
-    temperature: float=1.0, # [optional] The value used to modulate the next token probabilities.
-    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
-    repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
-    length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation.
-    enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
-    enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
-    enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
-    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    peft_model: str = None,
+    quantization: str = None,  # Options: 4bit, 8bit
+    max_new_tokens=256,  # The maximum numbers of tokens to generate
+    min_new_tokens: int = 0,  # The minimum numbers of tokens to generate
+    prompt_file: str = None,
+    seed: int = 42,  # seed value for reproducibility
+    safety_score_threshold: float = 0.5,
+    do_sample: bool = True,  # Whether or not to use sampling ; use greedy decoding otherwise.
+    use_cache: bool = True,  # [optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
+    top_p: float = 1.0,  # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+    temperature: float = 1.0,  # [optional] The value used to modulate the next token probabilities.
+    top_k: int = 50,  # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    repetition_penalty: float = 1.0,  # The parameter for repetition penalty. 1.0 means no penalty.
+    length_penalty: int = 1,  # [optional] Exponential penalty to the length that is used with beam-based generation.
+    enable_azure_content_safety: bool = False,  # Enable safety check with Azure content safety api
+    enable_sensitive_topics: bool = False,  # Enable check for sensitive topics using AuditNLG APIs
+    enable_saleforce_content_safety: bool = True,  # Enable safety check woth Saleforce safety flan t5
+    use_fast_kernels: bool = False,  # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     enable_llamaguard_content_safety: bool = False,
-    **kwargs
+    **kwargs,
 ):
+
     if prompt_file is not None:
         assert os.path.exists(
             prompt_file
         ), f"Provided Prompt file does not exist {prompt_file}"
-        dialogs= read_dialogs_from_file(prompt_file)
+        dialogs = read_dialogs_from_file(prompt_file)
     elif not sys.stdin.isatty():
         dialogs = "\n".join(sys.stdin.readlines())
         try:
             dialogs = json.loads(dialogs)
         except:
-            print("Could not parse json from stdin. Please provide a json file with the user prompts. Exiting.")
+            print(
+                "Could not parse json from stdin. Please provide a json file with the user prompts. Exiting."
+            )
             sys.exit(1)
     else:
         print("No user prompt provided. Exiting.")
@@ -59,7 +65,7 @@ def main(
 
     print(f"User dialogs:\n{dialogs}")
     print("\n==================================\n")
-    
+
     # Set the seeds for reproducibility
     if is_xpu_available():
         torch.xpu.manual_seed(seed)
@@ -77,13 +83,16 @@ def main(
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):
-            safety_checker = get_safety_checker(enable_azure_content_safety,
-                                        enable_sensitive_topics,
-                                        enable_saleforce_content_safety,
-                                        enable_llamaguard_content_safety,
-                                        )
+            safety_checker = get_safety_checker(
+                enable_azure_content_safety,
+                enable_sensitive_topics,
+                enable_saleforce_content_safety,
+                enable_llamaguard_content_safety,
+            )
             # Safety check of the user prompt
-            safety_results = [check(dialogs[idx][0]["content"]) for check in safety_checker]
+            safety_results = [
+                check(dialogs[idx][0]["content"]) for check in safety_checker
+            ]
             are_safe = all([r[1] for r in safety_results])
             if are_safe:
                 print(f"User prompt deemed safe.")
@@ -97,13 +106,15 @@ def main(
                         print(report)
                 print("Skipping the inferece as the prompt is not safe.")
                 sys.exit(1)  # Exit the program with an error status
-            tokens= torch.tensor(chat).long()
-            tokens= tokens.unsqueeze(0)
+            tokens = torch.tensor(chat).long()
+            tokens = tokens.unsqueeze(0)
             attention_mask = torch.ones_like(tokens)
             if is_xpu_available():
-                tokens= tokens.to("xpu:0")
+                tokens = tokens.to("xpu")
+                attention_mask = attention_mask.to("xpu")
             else:
-                tokens= tokens.to("cuda:0")
+                tokens = tokens.to("cuda")
+                attention_mask = attention_mask.to("cuda")
             outputs = model.generate(
                 input_ids=tokens,
                 attention_mask=attention_mask,
@@ -115,7 +126,7 @@ def main(
                 top_k=top_k,
                 repetition_penalty=repetition_penalty,
                 length_penalty=length_penalty,
-                **kwargs
+                **kwargs,
             )
 
             output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -136,6 +147,5 @@ def main(
                 print(report)
 
 
-
 if __name__ == "__main__":
     fire.Fire(main)
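
Usage sketch (not part of the patch): beyond import ordering and black-style layout, the only behavioral changes above are that the attention mask now follows the input tokens onto the target device and the device strings drop the explicit ":0" index. The fire-based CLI is unchanged, so the script can still be driven with --prompt_file or via stdin. The checkpoint path and file name below are placeholders, and the dialog format is inferred from read_dialogs_from_file and dialogs[idx][0]["content"] rather than stated in the diff.

    # chats.json: a list of dialogs, each dialog a list of messages with a "content" field, e.g.
    # [[{"role": "user", "content": "Give me a recipe for banana bread."}]]
    python chat_completion.py --model_name <path/to/llama/checkpoint> --prompt_file chats.json --quantization 8bit
    # or pipe the same JSON over stdin:
    cat chats.json | python chat_completion.py --model_name <path/to/llama/checkpoint>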