diff --git a/generate.py b/generate.py index 69d817723..eff086afd 100644 --- a/generate.py +++ b/generate.py @@ -110,11 +110,11 @@ def from_args(cls, args): ) return cls( - prompt=args.prompt, + prompt=getattr(args, "prompt", ""), encoded_prompt=None, chat_mode=args.chat, gui_mode=args.gui, - num_samples=args.num_samples, + num_samples=getattr(args, "num_samples", 1), max_new_tokens=args.max_new_tokens, top_k=args.top_k, temperature=args.temperature,