
Optimize tokenizer.decode() Performance for List[int] Inputs #36872

Closed
n0gu-furiosa opened this issue Mar 21, 2025 · 2 comments · Fixed by #36885
Labels: Feature request (Request for a new feature)

Comments

@n0gu-furiosa
Contributor

Feature request

When calling tokenizer.decode() with a List[int] as token_ids, the method appears to be significantly slower than necessary due to redundant to_py_obj conversions.

Motivation

Example:

import time
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('hf-internal-testing/llama-tokenizer')
token_ids = [869] * 2000  # a flat List[int] of 2000 token ids

# Decode the same list 1000 times and report total wall-clock time.
start = time.time()
for _ in range(1000):
    tok.decode(token_ids)
print(time.time() - start)

The trace results for the above code show that most of the time is spent on repeated to_py_obj calls, rather than in the actual _decode function:

[Screenshot: profiler trace showing most of the time spent in repeated to_py_obj calls rather than in _decode]

In this case, since the input is already a List[int], passing it through to_py_obj seems redundant. By adding a conditional check to bypass this line for List[int] inputs:

token_ids = to_py_obj(token_ids)

…the example code improves by nearly 10x in my environment (from ~7s to ~0.7s).
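A minimal sketch of the proposed conditional check, written as a standalone helper. The function names here are illustrative, not the actual `transformers` implementation; the simplified `to_py_obj` mirrors the recursive converter only far enough to show the bypass:

```python
from collections import UserDict

# Simplified stand-in for transformers' recursive converter,
# shown only to illustrate where the fast path would sit.
def to_py_obj(obj):
    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_py_obj(o) for o in obj]
    return obj

def normalize_token_ids(token_ids):
    # Fast path: a flat list of plain Python ints needs no conversion,
    # so skip the O(n) recursive function calls entirely.
    if isinstance(token_ids, list) and all(isinstance(t, int) for t in token_ids):
        return token_ids
    return to_py_obj(token_ids)
```

The `all(...)` scan is still O(n), but it is a single pass of cheap type checks rather than one Python function call per element.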

Your contribution

I wasn't sure whether the best place for this optimization is within decode() or inside to_py_obj(), so I haven't opened a PR yet. I'd be happy to contribute a fix if there's guidance on where such a change would be most appropriate.

@n0gu-furiosa added the Feature request label Mar 21, 2025
@Rocketknight1
Member

This is most likely caused by this section in to_py_obj():

if isinstance(obj, (dict, UserDict)):
    return {k: to_py_obj(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
    return [to_py_obj(o) for o in obj]

When a flat list is passed in, to_py_obj() will be called on each element of the list, which means 2000 function calls are required in your test. If you can figure out an optimization in that function that retains correct output without that speed penalty, we'd definitely welcome a PR!
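One way to avoid the per-element recursion is a short-circuit inside `to_py_obj()` itself. This is a hedged sketch of that idea, not the code actually merged in the fix:

```python
from collections import UserDict

# Illustrative variant of to_py_obj() with a flat-list fast path;
# not the actual transformers implementation.
def to_py_obj(obj):
    if isinstance(obj, (dict, UserDict)):
        return {k: to_py_obj(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        # Fast path: if every element is already a plain Python int or
        # float, one all() pass replaces thousands of recursive calls.
        if all(isinstance(o, (int, float)) for o in obj):
            return list(obj)
        return [to_py_obj(o) for o in obj]
    return obj
```

Nested or mixed-type containers still fall through to the original recursive branch, so the output stays identical for those inputs.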

@n0gu-furiosa
Contributor Author

Thank you for your guidance @Rocketknight1! I've opened a PR for this.
