[Frontend] Clean up type annotations for mistral tokenizer #8314
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
```python
elif part_type == "refusal":
    text = _RefusalParser(part)["refusal"]
    texts.append(text)
```
I noticed that `_RefusalParser` got left out previously, so I'm adding it here.
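For context, a minimal sketch of what such a part parser might look like, assuming the parsers are thin typed `cast` helpers over OpenAI-style content parts (the `TypedDict` shape here is an assumption for illustration, not taken from this diff):

```python
from functools import partial
from typing import TypedDict, cast

class ChatCompletionContentPartRefusalParam(TypedDict):
    # assumed shape of an OpenAI-style refusal content part
    type: str     # "refusal"
    refusal: str  # the refusal text

# the parser is just a typed cast, so that indexing the result with
# ["refusal"] type-checks against the TypedDict above
_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam)
```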
```python
def _postprocess_messages(messages: List[ConversationMessage]) -> None:
    # per the Transformers docs & maintainers, tool call arguments in
    # assistant-role messages with tool_calls need to be dicts not JSON str -
    # this is how tool-use chat templates will expect them moving forwards
    # so, for messages that have tool_calls, parse the string (which we get
    # from openAI format) to dict
    for message in messages:
        if (message["role"] == "assistant" and "tool_calls" in message
                and isinstance(message["tool_calls"], list)):

            for item in message["tool_calls"]:
                item["function"]["arguments"] = json.loads(
                    item["function"]["arguments"])
```
I assume that mistral tokenizers will be able to handle tool calls internally, since the output conversation will not be used by the mistral tokenizer.
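As an illustration, a hypothetical before/after for `_postprocess_messages` (the message contents and tool name are invented for the example):

```python
import json

# an assistant message as received in OpenAI format: "arguments"
# arrives as a JSON string
message = {
    "role": "assistant",
    "tool_calls": [{
        "id": "call_0",
        "type": "function",
        "function": {
            "name": "get_weather",
            "arguments": '{"city": "Paris"}',
        },
    }],
}

# what _postprocess_messages does to it
for item in message["tool_calls"]:
    item["function"]["arguments"] = json.loads(item["function"]["arguments"])

# HF tool-use chat templates expect a dict here, not a JSON string
assert message["tool_calls"][0]["function"]["arguments"] == {"city": "Paris"}
```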
```python
prompt: Union[str, List[int]]
if isinstance(tokenizer, MistralTokenizer):
    prompt = apply_mistral_chat_template(
        tokenizer,
        messages=messages,
        chat_template=chat_template,
        add_generation_prompt=add_generation_prompt,
    )
else:
    prompt = apply_hf_chat_template(
        tokenizer,
        conversation=conversation,
        chat_template=chat_template,
        add_generation_prompt=add_generation_prompt,
    )
```
The main part of this PR. Notice that the mistral tokenizer uses `messages` while the HF tokenizer uses `conversation`. This is cleaner than having different parsing logic inside `parse_chat_messages`, as it avoids the need to handle different types of `conversation` when generating the output request.
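A sketch of the resulting split, under the signatures implied by the diff above: the HF path renders the parsed conversation into a prompt string, while the mistral path consumes the original OpenAI-style messages and returns token ids. The `tokenize=False` flag on the HF path and the abbreviated parameter lists are assumptions for illustration, not the exact vLLM signatures:

```python
from typing import Any, List, Optional

def apply_hf_chat_template(tokenizer, *, conversation: List[dict],
                           chat_template: Optional[str],
                           **kwargs: Any) -> str:
    # HF tokenizers render the parsed conversation into a prompt string
    return tokenizer.apply_chat_template(conversation=conversation,
                                         chat_template=chat_template,
                                         tokenize=False, **kwargs)

def apply_mistral_chat_template(tokenizer, *, messages: List[dict],
                                chat_template: Optional[str],
                                **kwargs: Any) -> List[int]:
    # the mistral tokenizer takes OpenAI-style messages directly and
    # returns token ids rather than a prompt string
    return tokenizer.apply_chat_template(messages=messages,
                                         chat_template=chat_template,
                                         **kwargs)
```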
```python
    chat_template: Optional[str],
    **kwargs: Any,
) -> List[int]:
    return tokenizer.apply_chat_template(
        messages=messages,
        chat_template=chat_template,
```
Suggested change:

```diff
-    chat_template: Optional[str],
     **kwargs: Any,
 ) -> List[int]:
     return tokenizer.apply_chat_template(
         messages=messages,
-        chat_template=chat_template,
```
Maybe out of scope for this PR, but mistral tokenizers will actually never need a `chat_template`.
Yeah, let's do this in another PR.
Thanks a lot for the PR @DarkLight1337 - that's indeed much cleaner!
I think we don't need to pass the chat_template to the mistral tokenizer function at all anymore, but I'm also happy to tackle this in another PR.
The expected inputs and outputs for the mistral tokenizer are different from those of HF tokenizers, so in this PR I have split them into different functions to avoid introducing many union types.
cc @patrickvonplaten since you originally worked on using mistral tokenizer in vLLM