[Mistral] Mistral-7B-v0.1 support #1196
Conversation
import torch
from torch import nn
from transformers import MistralConfig
This does not work because MistralConfig is not yet part of the released HF transformers (v4.33.3). Could you define this config class locally, like this?
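For instance, a minimal sketch of a standalone definition along these lines, following the usual PretrainedConfig pattern in transformers; the default values are assumptions taken from the Mistral-7B-v0.1 checkpoint and should be cross-checked against its config.json:

# Rough sketch of a locally defined MistralConfig, to be removed once the
# class ships in a released version of transformers. Defaults are assumed
# from Mistral-7B-v0.1 and may not match the eventual upstream class.
from transformers import PretrainedConfig


class MistralConfig(PretrainedConfig):
    model_type = "mistral"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=4096,
        intermediate_size=14336,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=8,
        hidden_act="silu",
        max_position_embeddings=4096 * 32,
        initializer_range=0.02,
        rms_norm_eps=1e-5,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=1,
        eos_token_id=2,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        sliding_window=4096,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_act = hidden_act
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.sliding_window = sliding_window
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )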
@timlacroix Besides this, it seems everything works fine!
OK, addressed. Will we need to change this back after the next release?
@timlacroix Yes. Once a new version of HF transformers that includes MistralConfig is released, we will remove the local definition.
The Mistral model is almost equivalent to Llama in terms of quantizing the model, so it would be super easy to extend support; I have already added Mistral to AutoAWQ. If you can modify that part of the model (sketched below), you will enable AWQ quantized models:
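For illustration, a rough, hypothetical sketch of the kind of change meant here, mirroring how the Llama implementation threads an optional quant_config through its constructor; the QuantizationConfig type, its import path, and the rest of the plumbing are assumptions, not the exact code in this PR:

# Hypothetical sketch only: accept an optional quant_config, as llama.py does,
# and forward it to the inner model so its linear layers can load AWQ weights.
from typing import Optional

from torch import nn


class MistralForCausalLM(nn.Module):

    def __init__(
        self,
        config,
        # "QuantizationConfig" is a forward reference; the concrete type and
        # its import path in vLLM are assumed, mirroring the Llama model.
        quant_config: Optional["QuantizationConfig"] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        # The real constructor would build MistralModel(config, quant_config)
        # and the LM head, passing quant_config down to each parallel linear
        # layer exactly as the Llama implementation does for AWQ.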
After that, you should be able to run inference with the quantized model that is already available: https://huggingface.co/casperhansen/mistral-7b-instruct-v0.1-awq

from vllm import LLM, SamplingParams

prompts = [
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
llm = LLM(model="casperhansen/mistral-7b-instruct-v0.1-awq", quantization="awq", dtype="half")
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
LGTM. Since this PR is not modifiable, I will fix some miscellaneous issues right after merging it.
Co-authored-by: timlacroix <t@mistral.ai>
No description provided.