Skip to content

[Feature] return_assistant_tokens_mask for SFT #3014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions test/llm/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,24 @@ def test_content_base(self):
The result is""",
]

def test_history_assistant_mask(self):
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-3B")
for test_case in self.TEST_CASES:
history = History.from_text(test_case, chat_template_name="qwen")
proc = history.apply_chat_template(
tokenizer=tokenizer,
chat_template_name="qwen",
add_generation_prompt=False,
return_dict=True,
return_assistant_tokens_mask=True,
)
if "assistant" in history.role:
assert proc["assistant_masks"].any()
else:
assert not proc["assistant_masks"].any()

def test_history_completion(self):
"""Test the History class's handling of complete and incomplete messages."""

Expand Down
106 changes: 100 additions & 6 deletions torchrl/data/llm/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@

import torch


from tensordict import lazy_stack, LazyStackedTensorDict, list_to_stack, TensorClass
from tensordict.utils import _maybe_correct_neg_dim

from torchrl._utils import logger as torchrl_logger


_CHAT_TEMPLATES = {
"chatml_format": """{% for message in messages %}
{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}
Expand All @@ -24,7 +26,64 @@
{{- '<|im_start|>assistant\n' }}
{%- endif %}
""",
"qwen": """'{%- if tools %}\n {{- \'<|im_start|>system\\n\' }}\n {%- if messages[0][\'role\'] == \'system\' %}\n {{- messages[0][\'content\'] }}\n {%- else %}\n {{- \'You are a helpful assistant.\' }}\n {%- endif %}\n {{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}\n {%- for tool in tools %}\n {{- "\\n" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\"name\\": <function-name>, \\"arguments\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}\n{%- else %}\n {%- if messages[0][\'role\'] == \'system\' %}\n {{- \'<|im_start|>system\\n\' + messages[0][\'content\'] + \'<|im_end|>\\n\' }}\n {%- else %}\n {{- \'<|im_start|>system\\nYou are a helpful assistant.<|im_end|>\\n\' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == "user") or (message.role == "system" and not loop.first) or (message.role == "assistant" and not message.tool_calls) %}\n {{- \'<|im_start|>\' + message.role + \'\\n\' + message.content + \'<|im_end|>\' + \'\\n\' }}\n {%- elif message.role == "assistant" %}\n {{- \'<|im_start|>\' + message.role }}\n {%- if message.content %}\n {{- \'\\n\' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- \'\\n<tool_call>\\n{"name": "\' }}\n {{- tool_call.name }}\n {{- \'", "arguments": \' }}\n {{- tool_call.arguments | tojson }}\n {{- \'}\\n</tool_call>\' }}\n {%- endfor %}\n {{- \'<|im_end|>\\n\' }}\n {%- elif message.role == "tool" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}\n {{- \'<|im_start|>user\' }}\n {%- endif %}\n {{- \'\\n<tool_response>\\n\' }}\n {{- message.content }}\n {{- \'\\n</tool_response>\' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}\n {{- \'<|im_end|>\\n\' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- \'<|im_start|>assistant\\n\' }}\n{%- endif %}\n'""",
"qwen": """
{%- if tools %}
{{- '<|im_start|>system\\n' }}
{%- if messages[0]['role'] == 'system' %}
{{- messages[0]['content'] }}
{%- else %}
{{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}
{%- endif %}
{{- "\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>" }}
{%- for tool in tools %}
{{- "\\n" }}
{{- tool | tojson }}
{%- endfor %}
{{- "\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n" }}
{%- else %}
{%- if messages[0]['role'] == 'system' %}
{{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}
{%- else %}
{{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}
{%- endif %}
{%- endif %}
{%- for message in messages %}
{%- if (message.role == "user") or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}
{%- elif (message.role == "assistant" and not message.tool_calls) %}
{% generation %} {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }} {% endgeneration %}
{%- elif message.role == "assistant" %}
{% generation %}{{- '<|im_start|>' + message.role }}
{%- if message.content %}
{{- '\\n' + message.content }}
{%- endif %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function is defined %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '\\n<tool_call>\\n{\\\"name\\\": \\\"' }}
{{- tool_call.name }}
{{- '\\\", \\\"arguments\\\": ' }}
{{- tool_call.arguments | tojson }}
{{- '}\\n</tool_call>' }}
{%- endfor %}
{{- '<|im_end|>\\n' }}{% endgeneration %}
{%- elif message.role == "tool" %}
{%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != "tool") %}
{{- '<|im_start|>user' }}
{%- endif %}
{{- '\\n<tool_response>\\n' }}
{{- message.content }}
{{- '\\n</tool_response>' }}
{%- if loop.last or (messages[loop.index0 + 1].role != "tool") %}
{{- '<|im_end|>\\n' }}
{%- endif %}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{% generation %}{{- '<|im_start|>assistant\\n' }}{% endgeneration %}
{%- endif %}
""",
}


Expand Down Expand Up @@ -210,12 +269,14 @@ def apply_chat_template(
tokenizer: transformers.AutoTokenizer | transformers.AutoProcessor, # noqa
add_generation_prompt: bool = True,
chat_template: str | None = None,
chat_template_name: Literal["chatml_format", "qwen"] | None = None,
continue_final_message: bool = False,
tokenize: bool = False,
tokenize: bool | None = None,
padding: bool | str = False,
truncation: bool | str = False,
return_tensors: str | None = "pt",
return_dict: bool = False,
return_tensors: str | None = None,
return_dict: bool | None = None,
return_assistant_tokens_mask: bool = False,
**kwargs,
):
"""Applies a chat template to the history.
Expand All @@ -224,37 +285,68 @@ def apply_chat_template(
tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor): The tokenizer to use.
add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to `True`.
chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
chat_template_name (Literal["chatml_format", "qwen"], optional): The name of the chat template to use.
Prevalent over `tokenizer.chat_template`. Defaults to `None`.
continue_final_message (bool, optional): Whether to continue the final message. Defaults to `False`.
tokenize (bool, optional): Whether to tokenize the output. Defaults to `False`.
padding (bool | str, optional): The padding strategy to use. Defaults to `False`.
truncation (bool | str, optional): The truncation strategy to use. Defaults to `False`.
return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
return_assistant_tokens_mask (bool, optional): Whether to return a mask of the assistant generated tokens.
For tokens generated by the assistant, the mask will contain `1`.
For user and system tokens, the mask will contain `0`.
This functionality is only available for chat templates that support it via the `{% generation %}` keyword.
Defaults to `False`.

.. note:: By default, the `"qwen"` chat template does not support this functionality. A modified version of the template
can be used by setting `chat_template_name="qwen"`, which will override the default template from the tokenizer.
For other tokenizers, similar edits can be made to the template and passed to the method via the `chat_template` argument.

**kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.

Returns:
The formatted history.
"""
if chat_template is None:
if tokenizer is None:
if chat_template_name is not None:
chat_template = _CHAT_TEMPLATES[chat_template_name]
chat_template_name = None
elif tokenizer is None:
raise RuntimeError(
"You must specify a tokenizer to use when chat_template is not specified."
)
chat_template = tokenizer.chat_template
else:
chat_template = tokenizer.chat_template
if chat_template is None:
chat_template = _CHAT_TEMPLATES["chatml_format"]
if tokenize is None:
if return_assistant_tokens_mask or return_tensors is not None:
tokenize = True
else:
tokenize = False
if tokenize:
if return_tensors is None:
return_tensors = "pt"
if return_dict is None and return_assistant_tokens_mask:
return_dict = True
elif return_dict is None:
return_dict = False

if self.ndim > 1:
return [
self[i].apply_chat_template(
tokenizer=tokenizer,
add_generation_prompt=add_generation_prompt,
chat_template=chat_template,
chat_template_name=chat_template_name,
tokenize=tokenize,
padding=padding,
truncation=truncation,
return_tensors=return_tensors,
continue_final_message=continue_final_message,
return_dict=return_dict,
return_assistant_tokens_mask=return_assistant_tokens_mask,
**kwargs,
)
for i in range(self.batch_size[0])
Expand All @@ -274,6 +366,8 @@ def apply_chat_template(
return_tensors=return_tensors,
continue_final_message=continue_final_message,
return_dict=return_dict,
return_assistant_tokens_mask=return_assistant_tokens_mask,
**kwargs,
)

@classmethod
Expand Down
Loading