Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public record Multi2VecGoogleVectorizer(
@SerializedName("imageFields") List<String> imageFields,
/** BLOB video properties included in the embedding. */
@SerializedName("videoFields") List<String> videoFields,
/** BLOB audio properties included in the embedding. */
@SerializedName("audioFields") List<String> audioFields,
/** TEXT properties included in the embedding. */
@SerializedName("textFields") List<String> textFields,
/** Weights of the included properties. */
Expand All @@ -43,6 +45,11 @@ private static record Weights(
* corresponding property names in {@code videoFields}.
*/
@SerializedName("videoWeights") List<Float> videoWeights,
/**
* Weights of the BLOB audio properties. Values appear in the same order as the
* corresponding property names in {@code audioFields}.
*/
@SerializedName("audioWeights") List<Float> audioWeights,
/**
* Weights of the TEXT properties. Values appear in the same order as the
* corresponding property names in {@code textFields}.
Expand Down Expand Up @@ -101,6 +108,7 @@ public Multi2VecGoogleVectorizer(
Integer videoIntervalSeconds,
List<String> imageFields,
List<String> videoFields,
List<String> audioFields,
List<String> textFields,
Weights weights,
VectorIndex vectorIndex,
Expand All @@ -114,6 +122,7 @@ public Multi2VecGoogleVectorizer(
this.videoIntervalSeconds = videoIntervalSeconds;
this.imageFields = imageFields;
this.videoFields = videoFields;
this.audioFields = audioFields;
this.textFields = textFields;
this.weights = weights;
this.vectorIndex = vectorIndex;
Expand All @@ -130,6 +139,7 @@ public Multi2VecGoogleVectorizer(Builder builder) {
builder.videoIntervalSeconds,
builder.imageFields,
builder.videoFields,
builder.audioFields,
builder.textFields,
builder.getWeights(),
builder.vectorIndex,
Expand All @@ -143,7 +153,9 @@ public static class Builder implements ObjectBuilder<Multi2VecGoogleVectorizer>
private List<String> imageFields;
private List<Float> imageWeights;
private List<String> videoFields;
private List<String> audioFields;
private List<Float> videoWeights;
private List<Float> audioWeights;
private List<String> textFields;
private List<Float> textWeights;

Expand Down Expand Up @@ -242,6 +254,35 @@ public Builder videoField(String field, float weight) {
return this;
}

/** Add BLOB audio properties to include in the embedding. */
public Builder audioFields(List<String> fields) {
this.audioFields = fields;
return this;
}

/** Add BLOB audio properties to include in the embedding. */
public Builder audioFields(String... fields) {
return audioFields(Arrays.asList(fields));
}

/**
* Add BLOB audio property to include in the embedding.
*
* @param field Property name.
* @param weight Custom weight between 0.0 and 1.0.
*/
public Builder audioField(String field, float weight) {
if (this.audioFields == null) {
this.audioFields = new ArrayList<>();
}
if (this.audioWeights == null) {
this.audioWeights = new ArrayList<>();
}
this.audioFields.add(field);
this.audioWeights.add(weight);
return this;
}

/** Add TEXT properties to include in the embedding. */
public Builder textFields(List<String> fields) {
this.textFields = fields;
Expand Down Expand Up @@ -272,8 +313,9 @@ public Builder textField(String field, float weight) {
}

protected Weights getWeights() {
if (this.textWeights != null || this.imageWeights != null || this.videoWeights != null) {
return new Weights(this.imageWeights, this.videoWeights, this.textWeights);
if (this.textWeights != null || this.imageWeights != null || this.videoWeights != null
|| this.audioWeights != null) {
return new Weights(this.imageWeights, this.videoWeights, this.audioWeights, this.textWeights);
}
return null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,8 @@ public static Object[][] testCases() {
.apiEndpoint("example.com")
.imageFields("a", "b")
.textFields("c")
.videoFields("d")),
.videoFields("d")
.audioFields("f")),
"""
{
"vectorIndexType": "hnsw",
Expand All @@ -804,7 +805,8 @@ public static Object[][] testCases() {
"location": "location",
"imageFields": ["a", "b"],
"textFields": ["c"],
"videoFields": ["d"]
"videoFields": ["d"],
"audioFields": ["f"]
}
}
}
Expand Down
Loading