Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 57 additions & 5 deletions src/main/java/io/weaviate/client6/v1/api/collections/Reranker.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,20 @@
import com.google.gson.stream.JsonWriter;

import io.weaviate.client6.v1.api.collections.rerankers.CohereReranker;
import io.weaviate.client6.v1.api.collections.rerankers.JinaAiReranker;
import io.weaviate.client6.v1.api.collections.rerankers.NvidiaReranker;
import io.weaviate.client6.v1.api.collections.rerankers.TransformersReranker;
import io.weaviate.client6.v1.api.collections.rerankers.VoyageAiReranker;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.TaggedUnion;
import io.weaviate.client6.v1.internal.json.JsonEnum;

public interface Reranker {
public interface Reranker extends TaggedUnion<Reranker.Kind, Object> {
public enum Kind implements JsonEnum<Kind> {
JINAAI("reranker-jinaai"),
VOYAGEAI("reranker-voyageai"),
NVIDIA("reranker-nvidia"),
TRANSFORMERS("reranker-transformers"),
COHERE("reranker-cohere");

private static final Map<String, Kind> jsonValueMap = JsonEnum.collectNames(Kind.values());
Expand All @@ -38,10 +47,6 @@ public static Kind valueOfJson(String jsonValue) {
}
}

Kind _kind();

Object _self();

/** Configure a default Cohere reranker module. */
public static Reranker cohere() {
return CohereReranker.of();
Expand All @@ -56,6 +61,53 @@ public static Reranker cohere(Function<CohereReranker.Builder, ObjectBuilder<Coh
return CohereReranker.of(fn);
}

/** Configure a default JinaAI reranker module. */
public static Reranker jinaai() {
return JinaAiReranker.of();
}

/**
* Configure a JinaAI reranker module.
*
* @param fn Lambda expression for optional parameters.
*/
public static Reranker jinaai(Function<JinaAiReranker.Builder, ObjectBuilder<JinaAiReranker>> fn) {
return JinaAiReranker.of(fn);
}

/** Configure a default VoyageAI reranker module. */
public static Reranker voyageai() {
return VoyageAiReranker.of();
}

/**
* Configure a VoyageAI reranker module.
*
* @param fn Lambda expression for optional parameters.
*/
public static Reranker voyageai(Function<VoyageAiReranker.Builder, ObjectBuilder<VoyageAiReranker>> fn) {
return VoyageAiReranker.of(fn);
}

/** Configure a default Nvidia reranker module. */
public static Reranker nvidia() {
return NvidiaReranker.of();
}

/**
* Configure a Nvidia reranker module.
*
* @param fn Lambda expression for optional parameters.
*/
public static Reranker nvidia(Function<NvidiaReranker.Builder, ObjectBuilder<NvidiaReranker>> fn) {
return NvidiaReranker.of(fn);
}

/** Configure a default Transformers reranker module. */
public static Reranker transformers() {
return new TransformersReranker();
}

public static enum CustomTypeAdapterFactory implements TypeAdapterFactory {
INSTANCE;

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package io.weaviate.client6.v1.api.collections.rerankers;

import java.util.function.Function;

import com.google.gson.annotations.SerializedName;

import io.weaviate.client6.v1.api.collections.Reranker;
import io.weaviate.client6.v1.internal.ObjectBuilder;

public record JinaAiReranker(
@SerializedName("model") String model) implements Reranker {

public static final String BASE_MULTILINGUAL_V1 = "jina-reranker-v2-base-multilingual";
public static final String BASE_ENGLISH_V1 = "jina-reranker-v1-base-en";
public static final String TURBO_ENGLISH_V1 = "jina-reranker-v1-turbo-en";
public static final String TINY_ENGLISH_V1 = "jina-reranker-v1-tiny-en";
public static final String COLBERT_ENGLISH_V1 = "jina-colbert-v1-en";

@Override
public Kind _kind() {
return Reranker.Kind.JINAAI;
}

@Override
public Object _self() {
return this;
}

public static JinaAiReranker of() {
return of(ObjectBuilder.identity());
}

public static JinaAiReranker of(Function<Builder, ObjectBuilder<JinaAiReranker>> fn) {
return fn.apply(new Builder()).build();
}

public JinaAiReranker(Builder builder) {
this(builder.model);
}

public static class Builder implements ObjectBuilder<JinaAiReranker> {
private String model;

public Builder model(String model) {
this.model = model;
return this;
}

@Override
public JinaAiReranker build() {
return new JinaAiReranker(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.weaviate.client6.v1.api.collections.rerankers;

import java.util.function.Function;

import com.google.gson.annotations.SerializedName;

import io.weaviate.client6.v1.api.collections.Reranker;
import io.weaviate.client6.v1.internal.ObjectBuilder;

public record NvidiaReranker(
@SerializedName("model") String model,
@SerializedName("baseUrl") String baseUrl) implements Reranker {

@Override
public Kind _kind() {
return Reranker.Kind.NVIDIA;
}

@Override
public Object _self() {
return this;
}

public static NvidiaReranker of() {
return of(ObjectBuilder.identity());
}

public static NvidiaReranker of(Function<Builder, ObjectBuilder<NvidiaReranker>> fn) {
return fn.apply(new Builder()).build();
}

public NvidiaReranker(Builder builder) {
this(builder.model, builder.baseUrl);
}

public static class Builder implements ObjectBuilder<NvidiaReranker> {
private String model;
private String baseUrl;

public Builder model(String model) {
this.model = model;
return this;
}

public Builder baseUrl(String baseUrl) {
this.baseUrl = baseUrl;
return this;
}

@Override
public NvidiaReranker build() {
return new NvidiaReranker(this);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package io.weaviate.client6.v1.api.collections.rerankers;

import io.weaviate.client6.v1.api.collections.Reranker;

public record TransformersReranker() implements Reranker {

@Override
public Kind _kind() {
return Reranker.Kind.NVIDIA;
}

@Override
public Object _self() {
return this;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package io.weaviate.client6.v1.api.collections.rerankers;

import java.util.function.Function;

import com.google.gson.annotations.SerializedName;

import io.weaviate.client6.v1.api.collections.Reranker;
import io.weaviate.client6.v1.internal.ObjectBuilder;

public record VoyageAiReranker(
@SerializedName("model") String model) implements Reranker {

public static final String RERANK_1 = "rerank-1";
public static final String RERANK_LITE_1 = "rerank-lite-1";
public static final String RERANK_2 = "rerank-2";
public static final String RERANK_LITE_2 = "rerank-2-lite";

@Override
public Kind _kind() {
return Reranker.Kind.VOYAGEAI;
}

@Override
public Object _self() {
return this;
}

public static VoyageAiReranker of() {
return of(ObjectBuilder.identity());
}

public static VoyageAiReranker of(Function<Builder, ObjectBuilder<VoyageAiReranker>> fn) {
return fn.apply(new Builder()).build();
}

public VoyageAiReranker(Builder builder) {
this(builder.model);
}

public static class Builder implements ObjectBuilder<VoyageAiReranker> {
private String model;

public Builder model(String model) {
this.model = model;
return this;
}

@Override
public VoyageAiReranker build() {
return new VoyageAiReranker(this);
}
}
}