Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix quantified inference #302

Merged
merged 5 commits into from
Sep 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions scripts/inference/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,17 +166,21 @@ def setup():
if args.lora_model is None:
args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
if args.load_in_4bit or args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
bnb_4bit_compute_dtype=load_type,
)

base_model = LlamaForCausalLM.from_pretrained(
args.base_model,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto',
quantization_config=BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
bnb_4bit_compute_dtype=load_type
)
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None
)

model_vocab_size = base_model.get_input_embeddings().weight.size(0)
Expand Down
14 changes: 9 additions & 5 deletions scripts/inference/inference_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,17 +110,21 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT):
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
else:
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
if args.load_in_4bit or args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
bnb_4bit_compute_dtype=load_type,
)

base_model = LlamaForCausalLM.from_pretrained(
args.base_model,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto',
quantization_config=BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
bnb_4bit_compute_dtype=load_type
)
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None
)

model_vocab_size = base_model.get_input_embeddings().weight.size(0)
Expand Down
15 changes: 9 additions & 6 deletions scripts/openai_server_demo/openai_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,20 @@
if args.lora_model is None:
args.tokenizer_path = args.base_model
tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)

if args.load_in_4bit or args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
bnb_4bit_compute_dtype=load_type,
)
base_model = LlamaForCausalLM.from_pretrained(
args.base_model,
torch_dtype=load_type,
low_cpu_mem_usage=True,
device_map='auto' if not args.only_cpu else None,
quantization_config=BitsAndBytesConfig(
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
bnb_4bit_compute_dtype=load_type
)
load_in_4bit=args.load_in_4bit,
load_in_8bit=args.load_in_8bit,
quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None
)

model_vocab_size = base_model.get_input_embeddings().weight.size(0)
Expand Down