diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java index 08b6aeb82..7358cdd21 100644 --- a/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java +++ b/src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java @@ -14,11 +14,20 @@ import com.google.gson.stream.JsonWriter; import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker; +import io.weaviate.client6.v1.api.collections.rerankers.JinaAiReranker; +import io.weaviate.client6.v1.api.collections.rerankers.NvidiaReranker; +import io.weaviate.client6.v1.api.collections.rerankers.TransformersReranker; +import io.weaviate.client6.v1.api.collections.rerankers.VoyageAiReranker; import io.weaviate.client6.v1.internal.ObjectBuilder; +import io.weaviate.client6.v1.internal.TaggedUnion; import io.weaviate.client6.v1.internal.json.JsonEnum; -public interface Reranker { +public interface Reranker extends TaggedUnion { public enum Kind implements JsonEnum { + JINAAI("reranker-jinaai"), + VOYAGEAI("reranker-voyageai"), + NVIDIA("reranker-nvidia"), + TRANSFORMERS("reranker-transformers"), COHERE("reranker-cohere"); private static final Map jsonValueMap = JsonEnum.collectNames(Kind.values()); @@ -38,10 +47,6 @@ public static Kind valueOfJson(String jsonValue) { } } - Kind _kind(); - - Object _self(); - /** Configure a default Cohere reranker module. */ public static Reranker cohere() { return CohereReranker.of(); @@ -56,6 +61,53 @@ public static Reranker cohere(Function> fn) { + return JinaAiReranker.of(fn); + } + + /** Configure a default VoyageAI reranker module. */ + public static Reranker voyageai() { + return VoyageAiReranker.of(); + } + + /** + * Configure a VoyageAI reranker module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Reranker voyageai(Function> fn) { + return VoyageAiReranker.of(fn); + } + + /** Configure a default Nvidia reranker module. */ + public static Reranker nvidia() { + return NvidiaReranker.of(); + } + + /** + * Configure a Nvidia reranker module. + * + * @param fn Lambda expression for optional parameters. + */ + public static Reranker nvidia(Function> fn) { + return NvidiaReranker.of(fn); + } + + /** Configure a default Transformers reranker module. */ + public static Reranker transformers() { + return new TransformersReranker(); + } + public static enum CustomTypeAdapterFactory implements TypeAdapterFactory { INSTANCE; diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/JinaAiReranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/JinaAiReranker.java new file mode 100644 index 000000000..7c7996c3c --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/JinaAiReranker.java @@ -0,0 +1,54 @@ +package io.weaviate.client6.v1.api.collections.rerankers; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Reranker; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record JinaAiReranker( + @SerializedName("model") String model) implements Reranker { + + public static final String BASE_MULTILINGUAL_V1 = "jina-reranker-v2-base-multilingual"; + public static final String BASE_ENGLISH_V1 = "jina-reranker-v1-base-en"; + public static final String TURBO_ENGLISH_V1 = "jina-reranker-v1-turbo-en"; + public static final String TINY_ENGLISH_V1 = "jina-reranker-v1-tiny-en"; + public static final String COLBERT_ENGLISH_V1 = "jina-colbert-v1-en"; + + @Override + public Kind _kind() { + return Reranker.Kind.JINAAI; + } + + @Override + public Object _self() { + return this; + } + + public static JinaAiReranker of() { + return of(ObjectBuilder.identity()); + } + + public static JinaAiReranker of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public JinaAiReranker(Builder builder) { + this(builder.model); + } + + public static class Builder implements ObjectBuilder { + private String model; + + public Builder model(String model) { + this.model = model; + return this; + } + + @Override + public JinaAiReranker build() { + return new JinaAiReranker(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/NvidiaReranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/NvidiaReranker.java new file mode 100644 index 000000000..3ed849745 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/NvidiaReranker.java @@ -0,0 +1,55 @@ +package io.weaviate.client6.v1.api.collections.rerankers; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Reranker; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record NvidiaReranker( + @SerializedName("model") String model, + @SerializedName("baseUrl") String baseUrl) implements Reranker { + + @Override + public Kind _kind() { + return Reranker.Kind.NVIDIA; + } + + @Override + public Object _self() { + return this; + } + + public static NvidiaReranker of() { + return of(ObjectBuilder.identity()); + } + + public static NvidiaReranker of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public NvidiaReranker(Builder builder) { + this(builder.model, builder.baseUrl); + } + + public static class Builder implements ObjectBuilder { + private String model; + private String baseUrl; + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + @Override + public NvidiaReranker build() { + return new NvidiaReranker(this); + } + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/TransformersReranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/TransformersReranker.java new file mode 100644 index 000000000..a1dccfc78 --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/TransformersReranker.java @@ -0,0 +1,16 @@ +package io.weaviate.client6.v1.api.collections.rerankers; + +import io.weaviate.client6.v1.api.collections.Reranker; + +public record TransformersReranker() implements Reranker { + + @Override + public Kind _kind() { + return Reranker.Kind.NVIDIA; + } + + @Override + public Object _self() { + return this; + } +} diff --git a/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/VoyageAiReranker.java b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/VoyageAiReranker.java new file mode 100644 index 000000000..bb1c3757f --- /dev/null +++ b/src/main/java/io/weaviate/client6/v1/api/collections/rerankers/VoyageAiReranker.java @@ -0,0 +1,53 @@ +package io.weaviate.client6.v1.api.collections.rerankers; + +import java.util.function.Function; + +import com.google.gson.annotations.SerializedName; + +import io.weaviate.client6.v1.api.collections.Reranker; +import io.weaviate.client6.v1.internal.ObjectBuilder; + +public record VoyageAiReranker( + @SerializedName("model") String model) implements Reranker { + + public static final String RERANK_1 = "rerank-1"; + public static final String RERANK_LITE_1 = "rerank-lite-1"; + public static final String RERANK_2 = "rerank-2"; + public static final String RERANK_LITE_2 = "rerank-2-lite"; + + @Override + public Kind _kind() { + return Reranker.Kind.VOYAGEAI; + } + + @Override + public Object _self() { + return this; + } + + public static VoyageAiReranker of() { + return of(ObjectBuilder.identity()); + } + + public static VoyageAiReranker of(Function> fn) { + return fn.apply(new Builder()).build(); + } + + public VoyageAiReranker(Builder builder) { + this(builder.model); + } + + public static class Builder implements ObjectBuilder { + private String model; + + public Builder model(String model) { + this.model = model; + return this; + } + + @Override + public VoyageAiReranker build() { + return new VoyageAiReranker(this); + } + } +}