Update langchain hparams (#271)
1. `temperature` and `repetition_penalty` should be moved to `pipeline_kwargs`
2. Keep other hparams consistent with `inference_hf.py`
ymcui committed Sep 13, 2023
1 parent cedb6da commit ba4e228
Showing 2 changed files with 18 additions and 8 deletions.
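
Background on why the move matters: in LangChain's `HuggingFacePipeline.from_model_id`, `model_kwargs` are forwarded to the Hugging Face `from_pretrained()` call at load time, while `pipeline_kwargs` are forwarded to the `transformers` pipeline that runs generation, so sampling hyperparameters such as `temperature` and `repetition_penalty` only take effect there. A minimal standalone sketch of the corrected pattern (the model id below is a placeholder, not from this repo; the scripts use their own `model_path`):

import torch
from langchain.llms import HuggingFacePipeline

llm = HuggingFacePipeline.from_model_id(
    model_id="meta-llama/Llama-2-7b-hf",   # placeholder; the scripts pass model_path
    task="text-generation",
    device=0,
    pipeline_kwargs={                      # generation-time settings, applied by the pipeline
        "max_new_tokens": 400,
        "do_sample": True,
        "temperature": 0.2,
        "top_k": 40,
        "top_p": 0.9,
        "repetition_penalty": 1.1,
    },
    model_kwargs={                         # load-time settings for from_pretrained()
        "torch_dtype": torch.float16,
        "low_cpu_mem_usage": True,
    },
)
print(llm("Why move temperature into pipeline_kwargs?"))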
13 changes: 9 additions & 4 deletions scripts/langchain/langchain_qa.py
@@ -73,11 +73,16 @@
 model = HuggingFacePipeline.from_model_id(model_id=model_path,
                                           task="text-generation",
                                           device=0,
+                                          pipeline_kwargs={
+                                              "max_new_tokens": 400,
+                                              "do_sample": True,
+                                              "temperature": 0.2,
+                                              "top_k": 40,
+                                              "top_p": 0.9,
+                                              "repetition_penalty": 1.1},
                                           model_kwargs={
-                                              "torch_dtype" : load_type,
-                                              "low_cpu_mem_usage" : True,
-                                              "temperature": 0.2,
-                                              "repetition_penalty":1.1}
+                                              "torch_dtype": load_type,
+                                              "low_cpu_mem_usage": True}
                                           )
 
 if args.chain_type == "stuff":
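
For orientation, a hedged sketch of how the model built above is typically consumed in the `stuff` branch of `langchain_qa.py`; the `retriever` here is an assumption standing in for the vector-store retriever the script constructs outside this diff:

from langchain.chains import RetrievalQA

qa = RetrievalQA.from_chain_type(
    llm=model,            # the HuggingFacePipeline configured in the diff above
    chain_type="stuff",   # matches the args.chain_type == "stuff" branch
    retriever=retriever,  # assumed: built elsewhere in langchain_qa.py
)
print(qa.run("your question here"))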
13 changes: 9 additions & 4 deletions scripts/langchain/langchain_sum.py
@@ -50,11 +50,16 @@
 model = HuggingFacePipeline.from_model_id(model_id=model_path,
                                           task="text-generation",
                                           device=0,
+                                          pipeline_kwargs={
+                                              "max_new_tokens": 400,
+                                              "do_sample": True,
+                                              "temperature": 0.2,
+                                              "top_k": 40,
+                                              "top_p": 0.9,
+                                              "repetition_penalty": 1.1},
                                           model_kwargs={
-                                              "torch_dtype" : load_type,
-                                              "low_cpu_mem_usage" : True,
-                                              "temperature": 0.2,
-                                              "repetition_penalty":1.1}
+                                              "torch_dtype" : load_type,
+                                              "low_cpu_mem_usage" : True}
                                           )
 
 PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"])
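
Likewise for the summarization script, a hedged sketch of how `PROMPT` and the model typically feed LangChain's summarize chain; `docs` is an assumed list of already-loaded `Document` objects, not shown in this diff:

from langchain.chains.summarize import load_summarize_chain

chain = load_summarize_chain(
    model,               # the HuggingFacePipeline configured in the diff above
    chain_type="stuff",  # assumed for the sketch; other chain types exist
    prompt=PROMPT,       # the PromptTemplate defined above
)
print(chain.run(docs))   # docs: assumed list of Document chunks loaded earlier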
