
Conversation

Contributor

@lkhphuc lkhphuc commented Oct 6, 2025

Follow up on the previous comment #1615 (comment): in this PR I refactor the special-token handling by subclassing the tokenizer and adding the special tokens as class attributes there.

Also, we add the special tokens dynamically to the tokenizer, which will only create new tokens if they do not already exist in its vocab.

This makes things cleaner later when we provide the script to load a pretrained LLM, since we won't have to instruct users to separately download and manually modify the tokenizer's config.
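For reference, a minimal sketch of the idea, assuming the base class is the existing HuggingFace-style tokenizer in torchtitan/components/tokenizer.py and that the token strings and attribute names shown here are just placeholders (the actual class in the PR may differ):

```python
from torchtitan.components.tokenizer import HuggingFaceTokenizer  # assumed base class name


class VLMTokenizer(HuggingFaceTokenizer):
    # Special tokens the VLM needs, declared as class attributes.
    pad_token = "<|pad|>"
    img_token = "<|image|>"
    boi_token = "<|begin_of_image|>"
    eoi_token = "<|end_of_image|>"

    def __init__(self, tokenizer_path: str):
        # Assumes the base class loads the underlying tokenizers.Tokenizer as self.tokenizer.
        super().__init__(tokenizer_path)
        # add_special_tokens() only creates ids for tokens missing from the vocab
        # and returns how many were newly added, so this is a no-op when the
        # underlying tokenizer already defines them.
        num_new = self.tokenizer.add_special_tokens(
            [self.pad_token, self.img_token, self.boi_token, self.eoi_token]
        )
        if num_new > 0:
            print(f"Extended the vocab with {num_new} new special tokens")
        self.pad_id = self.tokenizer.token_to_id(self.pad_token)
        self.img_id = self.tokenizer.token_to_id(self.img_token)
```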

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 6, 2025
Contributor

@wwwjn wwwjn left a comment


Nice refactor, LGTM

pixel_values: torch.Tensor,
grid_thw: torch.Tensor,
special_tokens: SpecialTokens,
img_id: int,
Contributor


nit: If we don't read the name img_id in the context of the tokenizer, it might have other meanings. Previously we referred to it as tokenizer.img_id, which should be fine, but here can we rename it to something like img_token_id in forward()?

Member

@H-Huang H-Huang left a comment


Looks good, thanks for refactoring!

from torchtitan.tools.logging import logger

from ..model.args import SpecialTokens
from ..tokenizer import VLMTokenizer as Tokenizer
Member


nit: I feel like leaving it just as VLMTokenizer is fine?

Contributor Author

lkhphuc commented Oct 8, 2025

Thanks for the review. I have just fixed the lint errors after applying the review changes.

Contributor

@wwwjn wwwjn left a comment


LGTM!

Contributor

@tianyu-l tianyu-l left a comment


I had a concern about whether this change is fine-tune friendly.


[model]
name = "llama3-siglip2"
name = "vlm"
Contributor


I think this is a confusion introduced by my earlier PR #1740. Now job_config.model.name is decoupled from train_spec.name.

I think we should probably rename this to job_config.model.folder. WDYT? @fegin @wwwjn

Contributor


I think we want to make these 2 fields aligned, not decoupled. We use train_spec.name to import the module, and we use the current job_config.model.name to find the correct train_spec. By changing the name, we want users to know clearly that train_spec.name should be the same as the current job_config.model.name, right?

I think both the old and new names work for me, as our current structure is always a 1-to-1 mapping (e.g., 1 folder contains only 1 train_spec, which represents only 1 model type).

Contributor


We use train_spec.name to import module

I don't think this is true, except for things explicitly registered into extra_train_specs? @wwwjn

By changing the name, we want user could know clearly train_spec.name should be the same as current job_config.model.name

But right now we don't have a way to restrict that. In general we shouldn't let users define redundant names anyway.

The ambiguity really comes from the dynamic import_module based on name search -- you don't see the TrainSpec name until you specify the model folder name in the toml file.

Contributor Author


Not sure if you want to support this, but locally I do have one model folder with multiple train_specs registered via register_train_spec directly. I currently work around it by putting the import of that folder back in the model/__init__.py to get all my train_specs.

If you want to support this case with dynamic import, I think adding another job_config.model.folder and letting get_train_spec() -> dict[TrainSpec] return multiple train specs makes sense.
Then we index the actual train spec with job_config.model.name.
If the model's folder (or name) is not defined, it defaults to the model's name (or folder).
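A rough sketch of that proposal, with hypothetical names (`job_config.model.folder`, a per-folder `get_train_specs()` returning a dict, and the import path) that do not exist in torchtitan today:

```python
import importlib

from torchtitan.protocols.train_spec import TrainSpec  # assumed import path


def resolve_train_spec(job_config) -> TrainSpec:
    # Hypothetical: folder defaults to name, and vice versa, when only one is set.
    folder = getattr(job_config.model, "folder", None) or job_config.model.name
    name = job_config.model.name or folder

    # Each model folder would expose get_train_specs() -> dict[str, TrainSpec],
    # so one folder can register several specs.
    module = importlib.import_module(f"torchtitan.models.{folder}")
    return module.get_train_specs()[name]
```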

Contributor


I'm seeing if we could remove TrainSpec.name: #1850
I think it doesn't hurt the out-of-repo use case and makes things unambiguous. Please let me know if you are OK with it.

self.eoi_token,
]
]
num_new_tokens = self.tokenizer.add_special_tokens(_special_tokens)
Contributor


Could you explain the difference between this simple treatment vs. the complicated logic in components/tokenizer.py?

I feel the key difference is that components/tokenizer.py has an infer step, where these tokens could have been defined for the underlying model.

def _get_token_from_config(self, config: dict[str, Any], key: str) -> Optional[str]:

for token_id_str, token_config in added_tokens_decoder.items():

While the simple treatment here may be enough to train a VLM given a pretrained LLM, I'm worried it might cause trouble if people would like to fine-tune a trained VLM, as the trained checkpoint should come with a tokenizer having these tokens, which may or may not have the same content as what you defined here.

pad_token and img_token may not matter as they won't affect model computation, but boi_token and eoi_token might, as they could have trained embeddings in the model? Please correct me if this sounds wrong.

Even without the above consideration, there could be an id collision between e.g. "<|pad|>" for image padding here and other padding purposes embedded in the original model tokenizer?

I had thought that the VLM tokenizer would override methods / variables like _infer_special_tokens and standard_keys, although it might need some work to make sure it works in both pretrain and finetune cases.

def _infer_special_tokens(self):

Contributor Author


Could you explain the difference between this simple treatment vs. the complicated logic in components/tokenizer.py?
...
I had thought that the VLM tokenizer would override methods / variables like _infer_special_tokens and standard_keys, although it might need some work to make sure it works in both pretrain and finetune cases

I agree the default tokenizer is a bit complicated, and possibly there's some redundancy regarding the creation of the bos_id and eos_id attributes. At first I did try to refactor it, but decided not to in case I'm not aware of some edge cases.

I feel the key difference is that component/tokenizer.py has an infer step, where these tokens could have been defined for the underlying model.

For the infer step, the Tokenizer package already handles the case where a token already exists: it will only create new tokens that don't exist yet (and returns the count of newly added tokens). https://huggingface.co/docs/tokenizers/api/tokenizer#tokenizers.Tokenizer.add_special_tokens
I think the main purpose of the _infer_special_tokens() method is to gather the already defined special tokens from all the config files?

While the simple treatment here may be enough to train a VLM given a pretrained LLM, I'm worried it might cause trouble if people would like to fine-tune a trained VLM, as the trained checkpoint should come with a tokenizer having these tokens, which may or may not have the same content as what you defined here.

My logical flow is like this:

  • A VLM needs some special tokens, and different choices of VLM might need different numbers of special tokens. So a VLM needs to declare what it needs, either in the model_args or in a custom tokenizer (like here currently).
  • For both pretraining and finetuning a VLM, it will inherit the tokenizer of the LLM. To match the logical function of the special tokens, we can only match by the token string itself.
  • And since different LLMs use different token strings for the same function (<|pad|>, <pad>, <|padding|>, etc.), there is no way but to manually check and verify each pre-trained LLM/VLM we want to use.

If we want to flexibly support different base LLMs/VLMs without subclassing a tokenizer every time, we could:

  • Declare the dict of required special tokens and their token strings in the model_args.
  • Dynamically add all those special tokens to the base tokenizer if they don't already exist (already handled by the underlying tokenizer class).
  • This function can be defined in this VLMTokenizer subclass, or merged with the current tokenizer like:
    __init__(self, ...):
       ...
        self._infer_special_tokens()
        self._infer_should_add_bos_eos()
        self._add_special_tokens(special_token_dict)  # Could be empty
  • Or refactoring the current tokenizer is another option:
    __init__(self, tokenizer_path, special_tokens: dict | None):
       ...
        self._infer_special_tokens(special_tokens)
        self._infer_should_add_bos_eos()

This looks cleanest to me and was my initial approach, but I stopped due to the large code changes and the risk of breaking some edge case for some kind of tokenizer config.

WDYT?

Contributor

@tianyu-l tianyu-l Oct 9, 2025


@lkhphuc

A VLM needs some special tokens, and different choices of VLM might need different numbers of special tokens. So a VLM needs to declare what it needs, either in the model_args or in a custom tokenizer (like here currently).

And since different LLMs use different token strings for the same function (<|pad|>, <pad>, <|padding|>, etc.), there is no way but to manually check and verify each pre-trained LLM/VLM we want to use.

Question 1:
The current LLM tokenizer doesn't seem to need this specification, as it only declares a fixed set of keys without actually specifying their content; instead it does the lookup via configs. Can't we do this in VLM?

The only problem is that if we are doing VLM pretraining on top of an existing LLM, then the LLM tokenizer may not have those special tokens' contents defined in the configs / tokenizer. So we should have a fallback. But that also requires the LLM's vocab to leave space for new special tokens?

Question 2:
It seems not all keys in standard_keys are used. https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/tokenizer.py#L259-L267 The only used keys are bos_id and eos_id? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/tokenizer.py#L58-L59 O/w how would they be used (presumably by dataloader) if they are not even declared as a variable.

If we want to flexibly support different base LLM/VLM without subclassing a tokenizer everytime, we could:
Declare the dict of required special tokens and their token strings in the model_args.
Dynamically add all those special tokens to the base tokenizer if not exist (already handled by underlying tokenizer class)
This function can be ... merge with the current tokenizer like:

I'm OK with this if we have to do it. We can extend build_hf_tokenizer to accept a dict of special tokens. I'm not sure where we should define these special tokens, as model / model args shouldn't have a dependency on them?

Contributor Author


Question 1:
The current LLM tokenizer doesn't seem to need this specification, as it only declares a fixed set of keys without actually specifying their content; instead it does the lookup via configs. Can't we do this in VLM?

Sorry, not sure what you mean exactly. If you mean that in the LLM's tokenizer_config.json we have e.g. {"start_of_thinking": "<think>", "end_of_thinking": "</think>"}, then sure, we could modify the tokenizer_config.json, vocab.json and special_tokens.json like I previously did for the test tokenizer. But this would require some manual work for every base LLM, which is why I opted to do it programmatically in this PR.

In either case, one still needs to know/provide either the correct token name, e.g. "start_of_thinking", "start_of_think" or "think_start", or the token string, e.g. <think>, <|think|>, etc. The standard_keys tokens are just popular conventions and there is no guarantee that they will always exist for every tokenizer, maybe except for the bos, eos and pad tokens.

For VLM pretraining, we presume the tokenizer has no VLM-related special tokens, so we can just define new ones. For finetuning a VLM, one needs to check how the relevant tokens are actually named in the tokenizer to look up their token strings. In some cases, the pretrained VLM tokenizer could come with a couple of ambiguous tokens like reserved_1, reserved_2, reserved_3, and one must check the official doc or reference implementation to know which one was actually used for which purpose.

The only problem is that if we are doing VLM pretraining on top of an existing LLM, then the LLM tokenizer may not have those special tokens' contents defined in the configs / tokenizer. So we should have a fallback. But that also requires the LLM's vocab to leave space for new special tokens?

Yes, this fallback is the standard behaviour of tokenizer.add_special_tokens. If the token to be added is already in the vocab, it will be marked as special if it isn't already. If that token does not exist, it will expand the vocab and assign an id to that token automatically.
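For reference, a tiny demonstration of that fallback with the HuggingFace tokenizers package (the file path and token strings here are just placeholders):

```python
from tokenizers import Tokenizer

tok = Tokenizer.from_file("tokenizer.json")  # any pretrained tokenizer file

vocab_before = tok.get_vocab_size()
# Tokens already in the vocab are only re-marked as special; missing ones
# are appended to the vocab and get fresh ids.
num_added = tok.add_special_tokens(["<|pad|>", "<|image|>"])
vocab_after = tok.get_vocab_size()

assert vocab_after == vocab_before + num_added
print(num_added, tok.token_to_id("<|image|>"))
```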

Question 2:
It seems not all keys in standard_keys are used. https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/tokenizer.py#L259-L267 The only used keys are bos_id and eos_id? https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/tokenizer.py#L58-L59 O/w how would they be used (presumably by dataloader) if they are not even declared as a variable.

Yes, I believe the other tokens are not being used either, and they are not accessible in the current tokenizer. In transformers, AutoTokenizer.from_pretrained(...) will automatically create attributes for all the special tokens, similar to what I'm doing here with the special_tokens dict.

I'm OK with this if we have to do it. We can extend build_hf_tokenizer to accept a dict of special tokens. I'm not sure where we should define these special tokens, as model / model args shouldn't have dependency on them?

I will add this on top of the current tokenizer then. I think defining this dict in the model args is sensible, because it's what the model requires, and it can be different from model to model (some might not use img_start / img_end at all). Then the tokenizer will access the model's args and create any required tokens that don't yet exist.
It's up to the config provider to correctly reuse any existing special tokens that were pretrained for that functionality.
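Roughly what that could look like, as a sketch with hypothetical names (the dataclass field and builder function below are illustrations, not the final API):

```python
from dataclasses import dataclass, field

from tokenizers import Tokenizer


@dataclass
class VLMModelArgs:
    # Hypothetical field: special token strings this model requires, keyed by role.
    special_tokens: dict[str, str] = field(default_factory=lambda: {
        "pad_token": "<|pad|>",
        "img_token": "<|image|>",
        "boi_token": "<|begin_of_image|>",
        "eoi_token": "<|end_of_image|>",
    })


def build_tokenizer_for_model(tokenizer_file: str, model_args: VLMModelArgs) -> Tokenizer:
    """Load a pretrained tokenizer and add only the special tokens it is missing."""
    tok = Tokenizer.from_file(tokenizer_file)
    tok.add_special_tokens(list(model_args.special_tokens.values()))
    return tok
```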

Contributor

@tianyu-l tianyu-l Oct 12, 2025


@lkhphuc
Thanks! Please see whether I've got the idea right.

There are two things: token str and token id.

  • Token str is what data loader produces.
  • Token id is what model recognizes.
  • Tokenizer is the bridge between token str and token id.

token str

Token str has nothing to do with the model, as the model only ever sees tokenized inputs, which are always ids. Token str is only produced by the data loader.

Model code won't explicitly use any token ids, in the sense that it won't (and probably shouldn't) have any code like self.mlp(.., self.pad_id). Model code will implicitly support various token ids via the Embedding layer.

token id:

For pretraining, it doesn't matter which id is assigned to which special purpose. E.g., is 128001 assigned to padding and 128002 to img_start, or the other way around? It doesn't matter. What matters is that the vocab size is larger than the tokenizer's vocab length, o/w the nn.Embedding will hit an index-out-of-range error.

This won't automatically always be the case! E.g., as a toy example, say the text vocab is 1024 and the vocab dimension of the pretrained LLM is exactly 1024. Then without "vocab extension", this particular LLM can't be used for VLM training, because there is no room in the embedding table to host the extra image token ids.

For finetuning, it's important that the torchtitan tokenizer maps token str to token id consistently with pretraining, in the sense that:

  • If pretraining uses <padding> as the padding token str and maps it to 128001, it's important for torchtitan finetuning to also map the padding token to 128001, but the torchtitan tokenizer can use <pad> as the token str, as long as the data loader produces <pad> instead of <padding>.

I think define this dict in the model args is sensible, because it's what the model required, and it can be different from model to model (some might not use the img_start img_end at all).

I agree that this is different from model to model. But I think that's because different models depend on different tokenizers (during pretraining). If what I said above is correct, I think we should separate the model definition and the tokenizer configs because

  • Yes, the token id part of the tokenizer should be consistent with the (pretrained) model embedding.
  • But the token str part should be independent of the model, and only tied to the data loader.
  • So I think the full config for this "bridge" shouldn't live inside either party, data loader or model.

So I would still prefer we have one tokenizer class for each model, if necessary. I think what you are doing in VLMTokenizer makes sense to me.

For finetuning, it's very important for the model owner to make sure the definition of VLMTokenizer is consistent with the tokenizer_config.json from HuggingFace (in terms of token id, not token str!), so that whatever we do with the data loader, the generated token ids can be understood by the model embedding. In other words, we need to achieve whatever AutoTokenizer.from_pretrained does.

I think this means:
Whatever tokens the data loader will use must be variables in the tokenizer class, e.g. tokenizer.pad_id, tokenizer.pad_token.

  • If we are doing finetuning with some pretrained model and its tokenizer config file, we don't need to define the token str.
  • If we are doing VLM pretraining (on top of a trained LLM), we need to explicitly define the token str, because the LLM tokenizer may not have them defined.
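One way to read that requirement, as a hedged sketch (file paths and token strings are placeholders; the point is only that the str→id mapping used for finetuning must match the pretrained checkpoint):

```python
from tokenizers import Tokenizer

pretrained = Tokenizer.from_file("pretrained_vlm/tokenizer.json")   # checkpoint's tokenizer
finetune = Tokenizer.from_file("torchtitan_assets/tokenizer.json")  # what the data loader will use

# The data loader may use its own token str (e.g. "<pad>" instead of "<padding>"),
# but the id it resolves to must equal the id the model embedding was trained with.
pretrain_pad_id = pretrained.token_to_id("<padding>")
assert finetune.token_to_id("<pad>") == pretrain_pad_id, (
    "padding token id drifted between pretraining and finetuning"
)
```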

BTW

Yes, I believe the other tokens are not being used either, and they are not accessible in the current tokenizer. In transformers, AutoTokenizer.from_pretrained(...) will automatically create attributes for all the special tokens, similar to what I'm doing here with the special_tokens dict.

Shall we remove those unused tokens in components/tokenizer.py?
In general, I feel we should be more explicit, less implicit if we agree on what Tokenizer in torchtitan should do.

cc @H-Huang @wwwjn

tianyu-l pushed a commit that referenced this pull request Oct 9, 2025
In VLM interleaved training, with native resolution and aspect ratio,
the number of tokens participating in the loss computation differs per rank.
Naive FSDP gradient averaging across data ranks can cause tokens on
ranks with fewer valid tokens to contribute more to the loss than tokens
on other ranks.
This PR addresses this via loss balancing, which incurs an additional comm
in the loss computation.
In practice, I haven't noticed any impact from this comm.

#### Quick sanity check
On each rank $i$, let the sum loss over its $N_i$ tokens be
$L_i = \sum_{j=1}^{N_i}\ell_{ij}$, with gradient
$g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$.

If we multiply the *loss* on each rank by a constant factor **c** (the
same for all ranks), then after `backward()`:

$$
\tilde g_i = c \cdot g_i .
$$

FSDP will *average* these gradients across ranks:

$$
g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i
                =\frac{c}{R}\sum_{i=1}^{R} g_i .
$$

We want this to equal the **global‑sample average**:

$$
g_{\text{true}}
=\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla
\ell_{ij}
   =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i .
$$

Thus for the FSDP gradient to be correct, we need

$$
\frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad
c=\frac{R}{N_{\text{total}}}.
$$

So the *right* scaling factor is $R/N_{\text{total}}$, which means dividing
the per-rank sum loss by $N_{\text{total}}/R$, i.e. the **average
number of tokens per rank**.
Intuitively, this is the same as the default cross-entropy loss, but instead
of dividing the sum loss on a rank by the number of tokens **on that rank**,
we now divide by the **average number of tokens across all ranks**.


P.S.: sorry, this PR is based on #1802, but I couldn't choose that as the
base branch. Maybe it will be easier to review once that PR is merged.
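A minimal sketch of the scaling described in that commit message (hypothetical function name; assumes FSDP averages gradients across `R = world_size` data ranks and that padded positions are marked with `ignore_index`):

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F


def balanced_cross_entropy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100):
    # Per-rank sum of the token losses: L_i = sum_j l_ij
    loss_sum = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(),
        ignore_index=ignore_index, reduction="sum",
    )

    # N_total via one extra all-reduce, then the average tokens per rank N_total / R.
    num_valid = (labels != ignore_index).sum().to(loss_sum.dtype)
    total_valid = num_valid.clone()
    dist.all_reduce(total_valid, op=dist.ReduceOp.SUM)
    avg_valid_per_rank = total_valid / dist.get_world_size()

    # Dividing by N_total / R is the c = R / N_total scaling, so FSDP's gradient
    # averaging recovers the global per-token average gradient.
    return loss_sum / avg_valid_per_rank
```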
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 13, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 15, 2025