Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

package org.springframework.ai.tokenizer;

import java.util.Base64;

import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingType;
Expand All @@ -34,43 +36,52 @@
*/
public class JTokkitTokenCountEstimator implements TokenCountEstimator {

/**
* The JTokkit encoding instance used for token counting.
*/
private final Encoding estimator;

/**
* Creates a new JTokkitTokenCountEstimator with default CL100K_BASE encoding.
*/
public JTokkitTokenCountEstimator() {
this(EncodingType.CL100K_BASE);
}

public JTokkitTokenCountEstimator(EncodingType tokenEncodingType) {
/**
* Creates a new JTokkitTokenCountEstimator with the specified encoding type.
* @param tokenEncodingType the encoding type to use for token counting
*/
public JTokkitTokenCountEstimator(final EncodingType tokenEncodingType) {
this.estimator = Encodings.newLazyEncodingRegistry().getEncoding(tokenEncodingType);
}

@Override
public int estimate(String text) {
public int estimate(final String text) {
if (text == null) {
return 0;
}
return this.estimator.countTokens(text);
}

@Override
public int estimate(MediaContent content) {
public int estimate(final MediaContent content) {
int tokenCount = 0;

if (content.getText() != null) {
tokenCount += this.estimate(content.getText());
}

if (!CollectionUtils.isEmpty(content.getMedia())) {

for (Media media : content.getMedia()) {

tokenCount += this.estimate(media.getMimeType().toString());

if (media.getData() instanceof String textData) {
tokenCount += this.estimate(textData);
}
else if (media.getData() instanceof byte[] binaryData) {
tokenCount += binaryData.length; // This is likely incorrect.
String base64 = Base64.getEncoder().encodeToString(binaryData);
tokenCount += this.estimate(base64);
}
}
}
Expand All @@ -79,7 +90,7 @@ else if (media.getData() instanceof byte[] binaryData) {
}

@Override
public int estimate(Iterable<MediaContent> contents) {
public int estimate(final Iterable<MediaContent> contents) {
int totalSize = 0;
for (MediaContent mediaContent : contents) {
totalSize += this.estimate(mediaContent);
Expand Down