
Fix quantized inference (#302)
Fixed possible mismatches caused by newer versions of dependencies: a BitsAndBytesConfig is now built only when 4-bit or 8-bit loading is requested, and quantization_config is passed as None otherwise (see the sketch below).
iMountTai committed Sep 21, 2023
1 parent ba4e228 commit 09eadc6
Showing 3 changed files with 27 additions and 16 deletions.
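
The same pattern is applied in all three scripts. Below is a minimal, self-contained sketch of it; the argument names, float16 compute dtype, and from_pretrained keywords are taken from the diffs, while the argparse setup is illustrative rather than the scripts' full option list:

import argparse
import torch
from transformers import BitsAndBytesConfig, LlamaForCausalLM

parser = argparse.ArgumentParser()
parser.add_argument('--base_model', required=True)
parser.add_argument('--load_in_4bit', action='store_true')
parser.add_argument('--load_in_8bit', action='store_true')
args = parser.parse_args()

load_type = torch.float16  # compute dtype, as in the scripts

# Build a BitsAndBytesConfig only when quantized loading is requested.
if args.load_in_4bit or args.load_in_8bit:
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=args.load_in_4bit,
        load_in_8bit=args.load_in_8bit,
        bnb_4bit_compute_dtype=load_type,
    )

base_model = LlamaForCausalLM.from_pretrained(
    args.base_model,
    torch_dtype=load_type,
    low_cpu_mem_usage=True,
    device_map='auto',
    load_in_4bit=args.load_in_4bit,
    load_in_8bit=args.load_in_8bit,
    # Pass None (rather than an empty BitsAndBytesConfig) when no
    # quantization is requested, avoiding mismatches with newer
    # dependency versions.
    quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None,
)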
14 changes: 9 additions & 5 deletions scripts/inference/gradio_demo.py
@@ -166,17 +166,21 @@ def setup():
     if args.lora_model is None:
         args.tokenizer_path = args.base_model
     tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
+    if args.load_in_4bit or args.load_in_8bit:
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=args.load_in_4bit,
+            load_in_8bit=args.load_in_8bit,
+            bnb_4bit_compute_dtype=load_type,
+        )
 
     base_model = LlamaForCausalLM.from_pretrained(
         args.base_model,
         torch_dtype=load_type,
         low_cpu_mem_usage=True,
         device_map='auto',
-        quantization_config=BitsAndBytesConfig(
-            load_in_4bit=args.load_in_4bit,
-            load_in_8bit=args.load_in_8bit,
-            bnb_4bit_compute_dtype=load_type
-        )
+        load_in_4bit=args.load_in_4bit,
+        load_in_8bit=args.load_in_8bit,
+        quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None
     )
 
     model_vocab_size = base_model.get_input_embeddings().weight.size(0)
14 changes: 9 additions & 5 deletions scripts/inference/inference_hf.py
@@ -110,17 +110,21 @@ def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT):
         tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
     else:
         tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
+    if args.load_in_4bit or args.load_in_8bit:
+        quantization_config = BitsAndBytesConfig(
+            load_in_4bit=args.load_in_4bit,
+            load_in_8bit=args.load_in_8bit,
+            bnb_4bit_compute_dtype=load_type,
+        )
 
     base_model = LlamaForCausalLM.from_pretrained(
         args.base_model,
         torch_dtype=load_type,
         low_cpu_mem_usage=True,
         device_map='auto',
-        quantization_config=BitsAndBytesConfig(
-            load_in_4bit=args.load_in_4bit,
-            load_in_8bit=args.load_in_8bit,
-            bnb_4bit_compute_dtype=load_type
-        )
+        load_in_4bit=args.load_in_4bit,
+        load_in_8bit=args.load_in_8bit,
+        quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None
     )
 
     model_vocab_size = base_model.get_input_embeddings().weight.size(0)
15 changes: 9 additions & 6 deletions scripts/openai_server_demo/openai_api_server.py
@@ -69,17 +69,20 @@
 if args.lora_model is None:
     args.tokenizer_path = args.base_model
 tokenizer = LlamaTokenizer.from_pretrained(args.tokenizer_path, legacy=True)
 
+if args.load_in_4bit or args.load_in_8bit:
+    quantization_config = BitsAndBytesConfig(
+        load_in_4bit=args.load_in_4bit,
+        load_in_8bit=args.load_in_8bit,
+        bnb_4bit_compute_dtype=load_type,
+    )
 base_model = LlamaForCausalLM.from_pretrained(
     args.base_model,
     torch_dtype=load_type,
     low_cpu_mem_usage=True,
     device_map='auto' if not args.only_cpu else None,
-    quantization_config=BitsAndBytesConfig(
-        load_in_4bit=args.load_in_4bit,
-        load_in_8bit=args.load_in_8bit,
-        bnb_4bit_compute_dtype=load_type
-    )
+    load_in_4bit=args.load_in_4bit,
+    load_in_8bit=args.load_in_8bit,
+    quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None
 )
-
 model_vocab_size = base_model.get_input_embeddings().weight.size(0)
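
With this change, the scripts should run both with and without quantization. For example, assuming the --base_model and --load_in_4bit options implied by the args.* names in the diffs, and a placeholder model path:

# Full-precision inference; quantization_config=None is passed internally.
python scripts/inference/inference_hf.py --base_model path/to/model

# 4-bit quantized inference via the conditionally built BitsAndBytesConfig.
python scripts/inference/inference_hf.py --base_model path/to/model --load_in_4bit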
