diff --git a/spring-ai-commons/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java b/spring-ai-commons/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java index 1c1e7b9bc7a..84b828d995d 100644 --- a/spring-ai-commons/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java +++ b/spring-ai-commons/src/main/java/org/springframework/ai/tokenizer/JTokkitTokenCountEstimator.java @@ -16,6 +16,8 @@ package org.springframework.ai.tokenizer; +import java.util.Base64; + import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.EncodingType; @@ -34,18 +36,28 @@ */ public class JTokkitTokenCountEstimator implements TokenCountEstimator { + /** + * The JTokkit encoding instance used for token counting. + */ private final Encoding estimator; + /** + * Creates a new JTokkitTokenCountEstimator with default CL100K_BASE encoding. + */ public JTokkitTokenCountEstimator() { this(EncodingType.CL100K_BASE); } - public JTokkitTokenCountEstimator(EncodingType tokenEncodingType) { + /** + * Creates a new JTokkitTokenCountEstimator with the specified encoding type. + * @param tokenEncodingType the encoding type to use for token counting + */ + public JTokkitTokenCountEstimator(final EncodingType tokenEncodingType) { this.estimator = Encodings.newLazyEncodingRegistry().getEncoding(tokenEncodingType); } @Override - public int estimate(String text) { + public int estimate(final String text) { if (text == null) { return 0; } @@ -53,7 +65,7 @@ public int estimate(String text) { } @Override - public int estimate(MediaContent content) { + public int estimate(final MediaContent content) { int tokenCount = 0; if (content.getText() != null) { @@ -61,16 +73,15 @@ public int estimate(MediaContent content) { } if (!CollectionUtils.isEmpty(content.getMedia())) { - for (Media media : content.getMedia()) { - tokenCount += this.estimate(media.getMimeType().toString()); if (media.getData() instanceof String textData) { tokenCount += this.estimate(textData); } else if (media.getData() instanceof byte[] binaryData) { - tokenCount += binaryData.length; // This is likely incorrect. + String base64 = Base64.getEncoder().encodeToString(binaryData); + tokenCount += this.estimate(base64); } } } @@ -79,7 +90,7 @@ else if (media.getData() instanceof byte[] binaryData) { } @Override - public int estimate(Iterable contents) { + public int estimate(final Iterable contents) { int totalSize = 0; for (MediaContent mediaContent : contents) { totalSize += this.estimate(mediaContent);