Skip to content

Cache tiktoken encoding object #593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 17 additions & 9 deletions libraries/python/openai-client/openai_client/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import math
import re
from fractions import Fraction
from functools import lru_cache
from io import BytesIO
from typing import Any, Iterable, Sequence

Expand Down Expand Up @@ -51,12 +52,23 @@ def resolve_model_name(model: str) -> str:
raise NotImplementedError(f"num_tokens_from_messages() is not implemented for model {model}.")


@lru_cache(maxsize=16)
def _get_cached_encoding(resolved_model: str) -> tiktoken.Encoding:
    """Return the tiktoken Encoding for an already-resolved model name.

    tiktoken encoding construction is slow, so results are memoized in a
    bounded LRU cache keyed by the resolved model name. Unknown models fall
    back to a default encoding instead of raising.

    Args:
        resolved_model: Model name already normalized by resolve_model_name().

    Returns:
        The tiktoken Encoding for the model, or a fallback encoding
        (o200k_base for gpt-4o*/o* families, cl100k_base otherwise)
        when tiktoken does not recognize the model.
    """
    try:
        return tiktoken.encoding_for_model(resolved_model)
    except KeyError:
        # Newer model families (gpt-4o*, o1/o3/...) use o200k_base;
        # everything else falls back to cl100k_base.
        # NOTE(review): the "o" prefix also matches any future model whose
        # name merely starts with "o" — confirm that is intended.
        if resolved_model.startswith(("gpt-4o", "o")):
            # Lazy %-style args avoid formatting when the warning is filtered out.
            logger.warning("model %s not found. Using o200k_base encoding.", resolved_model)
            return tiktoken.get_encoding("o200k_base")
        logger.warning("model %s not found. Using cl100k_base encoding.", resolved_model)
        return tiktoken.get_encoding("cl100k_base")


def get_encoding_for_model(model: str) -> tiktoken.Encoding:
    """Get tiktoken encoding for a model, with caching for performance."""
    return _get_cached_encoding(resolve_model_name(model))


def num_tokens_from_message(message: ChatCompletionMessageParam, model: str) -> int:
Expand Down Expand Up @@ -209,11 +221,7 @@ def num_tokens_from_tools(
f"num_tokens_from_tools_and_messages() is not implemented for model {specific_model}."
)

try:
encoding = tiktoken.encoding_for_model(specific_model)
except KeyError:
logger.warning("model %s not found. Using o200k_base encoding.", specific_model)
encoding = tiktoken.get_encoding("o200k_base")
encoding = _get_cached_encoding(specific_model)

token_count = 0
for f in tools:
Expand Down