From 4583572426c793c283ff3bc8794c922eebc43993 Mon Sep 17 00:00:00 2001 From: Wael Karkoub Date: Thu, 11 Apr 2024 14:42:37 +0100 Subject: [PATCH] [Fix] Improves Token Limiter (#2350) * improves token limiter * improve docstr * rename arg --- .../contrib/capabilities/transforms.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/autogen/agentchat/contrib/capabilities/transforms.py b/autogen/agentchat/contrib/capabilities/transforms.py index d87963ec82e..f2ba6719118 100644 --- a/autogen/agentchat/contrib/capabilities/transforms.py +++ b/autogen/agentchat/contrib/capabilities/transforms.py @@ -85,7 +85,8 @@ class MessageTokenLimiter: 2. Individual messages are truncated based on max_tokens_per_message. For multimodal messages containing both text and other types of content, only the text content is truncated. 3. The overall conversation history is truncated based on the max_tokens limit. Once the accumulated token count - exceeds this limit, the current message being processed as well as any remaining messages are discarded. + exceeds this limit, the current message being processed get truncated to meet the total token count and any + remaining messages get discarded. 4. The truncated conversation history is reconstructed by prepending the messages to a new list to preserve the original message order. """ @@ -128,13 +129,20 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: total_tokens = sum(_count_tokens(msg["content"]) for msg in temp_messages) for msg in reversed(temp_messages): - msg["content"] = self._truncate_str_to_tokens(msg["content"]) - msg_tokens = _count_tokens(msg["content"]) + expected_tokens_remained = self._max_tokens - processed_messages_tokens - self._max_tokens_per_message - # If adding this message would exceed the token limit, discard it and all remaining messages - if processed_messages_tokens + msg_tokens > self._max_tokens: + # If adding this message would exceed the token limit, truncate the last message to meet the total token + # limit and discard all remaining messages + if expected_tokens_remained < 0: + msg["content"] = self._truncate_str_to_tokens( + msg["content"], self._max_tokens - processed_messages_tokens + ) + processed_messages.insert(0, msg) break + msg["content"] = self._truncate_str_to_tokens(msg["content"], self._max_tokens_per_message) + msg_tokens = _count_tokens(msg["content"]) + # prepend the message to the list to preserve order processed_messages_tokens += msg_tokens processed_messages.insert(0, msg) @@ -149,30 +157,30 @@ def apply_transform(self, messages: List[Dict]) -> List[Dict]: return processed_messages - def _truncate_str_to_tokens(self, contents: Union[str, List]) -> Union[str, List]: + def _truncate_str_to_tokens(self, contents: Union[str, List], n_tokens: int) -> Union[str, List]: if isinstance(contents, str): - return self._truncate_tokens(contents) + return self._truncate_tokens(contents, n_tokens) elif isinstance(contents, list): - return self._truncate_multimodal_text(contents) + return self._truncate_multimodal_text(contents, n_tokens) else: raise ValueError(f"Contents must be a string or a list of dictionaries. Received type: {type(contents)}") - def _truncate_multimodal_text(self, contents: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _truncate_multimodal_text(self, contents: List[Dict[str, Any]], n_tokens: int) -> List[Dict[str, Any]]: """Truncates text content within a list of multimodal elements, preserving the overall structure.""" tmp_contents = [] for content in contents: if content["type"] == "text": - truncated_text = self._truncate_tokens(content["text"]) + truncated_text = self._truncate_tokens(content["text"], n_tokens) tmp_contents.append({"type": "text", "text": truncated_text}) else: tmp_contents.append(content) return tmp_contents - def _truncate_tokens(self, text: str) -> str: + def _truncate_tokens(self, text: str, n_tokens: int) -> str: encoding = tiktoken.encoding_for_model(self._model) # Get the appropriate tokenizer encoded_tokens = encoding.encode(text) - truncated_tokens = encoded_tokens[: self._max_tokens_per_message] + truncated_tokens = encoded_tokens[:n_tokens] truncated_text = encoding.decode(truncated_tokens) # Decode back to text return truncated_text