Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chroma.vectorstore.ChromaApi.QueryRequest.Include;
import org.springframework.ai.chroma.vectorstore.common.ChromaApiConstants;
import org.springframework.ai.util.json.JsonParser;
Expand All @@ -52,6 +55,8 @@
*/
public class ChromaApi {

private static final Logger logger = LoggerFactory.getLogger(ChromaApi.class);

public static Builder builder() {
return new Builder();
}
Expand All @@ -62,6 +67,9 @@ public static Builder builder() {
// Regular expression pattern that looks for a message.
private static final Pattern MESSAGE_ERROR_PATTERN = Pattern.compile("\"message\":\"(.*?)\"");

// Regular expression pattern that looks for NotFoundError in JSON error response.
private static final Pattern NOT_FOUND_ERROR_PATTERN = Pattern.compile("NotFoundError\\('([^']*)'\\)");

private static final String X_CHROMA_TOKEN_NAME = "x-chroma-token";

private final ObjectMapper objectMapper;
Expand Down Expand Up @@ -136,12 +144,21 @@ public Tenant getTenant(String tenantName) {
.retrieve()
.body(Tenant.class);
}
catch (HttpServerErrorException | HttpClientErrorException e) {
String msg = this.getErrorMessage(e);
if (String.format("Tenant [%s] not found", tenantName).equals(msg)) {
catch (HttpClientErrorException e) {
if (isNotFoundError(e, "Tenant", tenantName)) {
String errorMessage = this.getErrorMessage(e);
if (StringUtils.hasText(errorMessage)) {
logger.debug("Tenant [{}] does not exist: {}, returning null", tenantName, errorMessage);
}
else {
logger.debug("Tenant [{}] does not exist, returning null", tenantName);
}
return null;
}
throw new RuntimeException(msg, e);
throw new RuntimeException(this.getErrorMessage(e), e);
}
catch (HttpServerErrorException e) {
throw new RuntimeException(this.getErrorMessage(e), e);
}
}

Expand All @@ -165,12 +182,23 @@ public Database getDatabase(String tenantName, String databaseName) {
.retrieve()
.body(Database.class);
}
catch (HttpServerErrorException | HttpClientErrorException e) {
String msg = this.getErrorMessage(e);
if (msg.startsWith(String.format("Database [%s] not found.", databaseName))) {
catch (HttpClientErrorException e) {
if (isNotFoundError(e, "Database", databaseName)) {
String errorMessage = this.getErrorMessage(e);
if (StringUtils.hasText(errorMessage)) {
logger.debug("Database [{}] in tenant [{}] does not exist: {}, returning null", databaseName,
tenantName, errorMessage);
}
else {
logger.debug("Database [{}] in tenant [{}] does not exist, returning null", databaseName,
tenantName);
}
return null;
}
throw new RuntimeException(msg, e);
throw new RuntimeException(this.getErrorMessage(e), e);
}
catch (HttpServerErrorException e) {
throw new RuntimeException(this.getErrorMessage(e), e);
}
}

Expand Down Expand Up @@ -226,12 +254,23 @@ public Collection getCollection(String tenantName, String databaseName, String c
.retrieve()
.body(Collection.class);
}
catch (HttpServerErrorException | HttpClientErrorException e) {
String msg = this.getErrorMessage(e);
if (String.format("Collection [%s] does not exists", collectionName).equals(msg)) {
catch (HttpClientErrorException e) {
if (isNotFoundError(e, "Collection", collectionName)) {
String errorMessage = this.getErrorMessage(e);
if (StringUtils.hasText(errorMessage)) {
logger.debug("Collection [{}] in database [{}] and tenant [{}] does not exist: {}, returning null",
collectionName, databaseName, tenantName, errorMessage);
}
else {
logger.debug("Collection [{}] in database [{}] and tenant [{}] does not exist, returning null",
collectionName, databaseName, tenantName);
}
return null;
}
throw new RuntimeException(msg, e);
throw new RuntimeException(this.getErrorMessage(e), e);
}
catch (HttpServerErrorException e) {
throw new RuntimeException(this.getErrorMessage(e), e);
}
}

Expand Down Expand Up @@ -326,7 +365,76 @@ private void httpHeaders(HttpHeaders headers) {
}
}

private boolean isNotFoundError(HttpClientErrorException e, String resourceType, String resourceName) {
// First, check the response body for JSON error response
String responseBody = e.getResponseBodyAsString();
if (StringUtils.hasText(responseBody)) {
try {
ErrorResponse errorResponse = this.objectMapper.readValue(responseBody, ErrorResponse.class);
String error = errorResponse.error();
// Check for NotFoundError('message') format
if (NOT_FOUND_ERROR_PATTERN.matcher(error).find()) {
return true;
}
// Check for "Resource [name] not found." format
String expectedPattern = String.format("%s [%s] not found.", resourceType, resourceName);
if (error.startsWith(expectedPattern)) {
return true;
}
// Check if error contains "not found" (case insensitive)
if (error.toLowerCase().contains("not found")) {
return true;
}
}
catch (JsonProcessingException ex) {
// If JSON parsing fails, check the response body directly
String expectedPattern = String.format("%s [%s] not found.", resourceType, resourceName);
if (responseBody.contains(expectedPattern) || responseBody.toLowerCase().contains("not found")) {
return true;
}
}
}

// Check the exception message
String errorMessage = e.getMessage();
if (StringUtils.hasText(errorMessage)) {
// Check for "Resource [name] not found." format
String expectedPattern = String.format("%s [%s] not found.", resourceType, resourceName);
if (errorMessage.contains(expectedPattern)) {
return true;
}
// Check if error message contains "not found" (case insensitive)
if (errorMessage.toLowerCase().contains("not found")) {
return true;
}
}

return false;
}

private String getErrorMessage(HttpStatusCodeException e) {
// First, try to parse the response body as JSON error response
String responseBody = e.getResponseBodyAsString();
if (StringUtils.hasText(responseBody)) {
try {
ErrorResponse errorResponse = this.objectMapper.readValue(responseBody, ErrorResponse.class);
// Extract error message from NotFoundError('message') format
Matcher notFoundErrorMatcher = NOT_FOUND_ERROR_PATTERN.matcher(errorResponse.error());
if (notFoundErrorMatcher.find()) {
return notFoundErrorMatcher.group(1);
}
// Return the error as-is if it doesn't match NotFoundError pattern
return errorResponse.error();
}
catch (JsonProcessingException ex) {
// If JSON parsing fails, return the response body as-is if it's not empty
if (StringUtils.hasText(responseBody.trim())) {
return responseBody.trim();
}
logger.debug("Failed to parse error response as JSON, falling back to exception message", ex);
}
}

var errorMessage = e.getMessage();

// If the error message is empty or null, return an empty string
Expand Down Expand Up @@ -622,6 +730,15 @@ private static class CollectionList extends ArrayList<Collection> {

}

/**
* Chroma API error response.
*
* @param error The error message from Chroma API.
*/
@JsonInclude(JsonInclude.Include.NON_NULL)
private record ErrorResponse(@JsonProperty("error") String error) {
}

public static final class Builder {

private String baseUrl = ChromaApiConstants.DEFAULT_BASE_URL;
Expand Down
Loading