Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds generative openai support #187

Merged
merged 2 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
WCS_DUMMY_CI_PW: ${{ secrets.WCS_DUMMY_CI_PW }}
OKTA_CLIENT_SECRET: ${{ secrets.OKTA_CLIENT_SECRET }}
AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }}
OPENAI_APIKEY: ${{ secrets.OPENAI_APIKEY }}
run: |
docker-compose -f src/test/resources/docker-compose-test.yaml pull
mvn clean test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import technology.semi.weaviate.client.v1.graphql.query.builder.GetBuilder;
import technology.semi.weaviate.client.v1.graphql.query.fields.Field;
import technology.semi.weaviate.client.v1.graphql.query.fields.Fields;
import technology.semi.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder;

public class Get extends BaseClient<GraphQLResponse> implements ClientResult<GraphQLResponse> {
private final GetBuilder.GetBuilderBuilder getBuilder;
Expand Down Expand Up @@ -106,6 +107,11 @@ public Get withSort(SortArgument... sort) {
return this;
}

public Get withGenerativeSearch(GenerativeSearchBuilder generativeSearch) {
this.getBuilder.withGenerativeSearch(generativeSearch);
return this;
}

@Override
public Result<GraphQLResponse> run() {
String getQuery = this.getBuilder.build().buildQuery();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@
import technology.semi.weaviate.client.v1.graphql.query.argument.NearTextArgument;
import technology.semi.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import technology.semi.weaviate.client.v1.graphql.query.argument.SortArguments;
import technology.semi.weaviate.client.v1.graphql.query.fields.Field;
import technology.semi.weaviate.client.v1.graphql.query.fields.Fields;
import technology.semi.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@Getter
@Builder
Expand All @@ -43,11 +51,12 @@ public class GetBuilder implements Query {
NearVectorArgument withNearVectorFilter;
GroupArgument withGroupArgument;
SortArguments withSortArguments;
GenerativeSearchBuilder withGenerativeSearch;

private boolean includesFilterClause() {
return ObjectUtils.anyNotNull(withWhereFilter, withNearTextFilter, withNearObjectFilter,
withNearVectorFilter, withNearImageFilter, withGroupArgument, withAskArgument,withBm25Filter, withHybridFilter,
limit, offset, withSortArguments);
return ObjectUtils.anyNotNull(withWhereFilter, withNearTextFilter, withNearObjectFilter, withNearVectorFilter,
withNearImageFilter, withGroupArgument, withAskArgument, withBm25Filter, withHybridFilter, limit, offset,
withSortArguments);
}

private String createFilterClause() {
Expand Down Expand Up @@ -98,7 +107,54 @@ private String createFilterClause() {
}

private String createFields() {
return fields != null ? fields.build() : "";
if (ObjectUtils.allNull(fields, withGenerativeSearch)) {
return "";
}

if (withGenerativeSearch == null) {
return fields.build();
}

Field generate = withGenerativeSearch.build();
Field generateAdditional = Field.builder()
.name("_additional")
.fields(new Field[]{generate})
.build();

if (fields == null) {
return generateAdditional.build();
}

// check if _additional field exists. If missing just add new _additional with generate,
// if exists merge generate into present one
Map<Boolean, List<Field>> grouped = Arrays.stream(fields.getFields())
.collect(Collectors.groupingBy(f -> "_additional".equals(f.getName())));

List<Field> additionals = grouped.getOrDefault(true, new ArrayList<>());
if (additionals.isEmpty()) {
additionals.add(generateAdditional);
} else {
Field[] mergedInternalFields = Stream.concat(
Arrays.stream(additionals.get(0).getFields()),
Stream.of(generate)
).toArray(Field[]::new);

additionals.set(0, Field.builder()
.name("_additional")
.fields(mergedInternalFields)
.build()
);
}

Field[] allFields = Stream.concat(
grouped.getOrDefault(false, new ArrayList<>()).stream(),
additionals.stream()
).toArray(Field[]::new);

return Fields.builder()
.fields(allFields)
.build()
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.Arrays;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
Expand All @@ -12,6 +13,7 @@
@Getter
@Builder
@ToString
@EqualsAndHashCode
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class Field implements Argument {
String name;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import java.util.Arrays;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.ObjectUtils;
import technology.semi.weaviate.client.v1.graphql.query.argument.Argument;

import java.util.Arrays;
import java.util.stream.Collectors;

@Getter
@Builder
@ToString
@EqualsAndHashCode
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class Fields implements Argument {
Field[] fields;

@Override
public String build() {
if (this.fields != null && this.fields.length > 0) {
return StringUtils.joinWith(" ", Arrays.stream(this.fields).map(Field::build).toArray());
if (ObjectUtils.isEmpty(fields)) {
return "";
}
return "";
return Arrays.stream(fields)
.map(Field::build)
.collect(Collectors.joining(" "));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.StringUtils;

import java.util.LinkedHashSet;
import java.util.Set;

@Getter
@Builder
@ToString
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GenerativeSearchBuilder {

String singleResultPrompt;
String groupedResultTask;


public Field build() {
Set<String> nameParts = new LinkedHashSet<>();
Set<String> fieldNames = new LinkedHashSet<>();

if (StringUtils.isNotBlank(singleResultPrompt)) {
nameParts.add(String.format("singleResult:{prompt:\"\"\"%s\"\"\"}", singleResultPrompt));
fieldNames.add("singleResult");
}
if (StringUtils.isNotBlank(groupedResultTask)) {
nameParts.add(String.format("groupedResult:{task:\"\"\"%s\"\"\"}", groupedResultTask));
fieldNames.add("groupedResult");
}

if (nameParts.isEmpty()) {
return Field.builder().build();
}

fieldNames.add("error");
String name = String.format("generate(%s)", StringUtils.join(nameParts, " "));
Field[] fields = fieldNames.stream()
.map(n -> Field.builder().name(n).build())
.toArray(Field[]::new);

return Field.builder().name(name).fields(fields).build();
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package technology.semi.weaviate.client.v1.graphql.query.builder;

import junit.framework.TestCase;
import org.junit.Test;
import technology.semi.weaviate.client.v1.filters.Operator;
import technology.semi.weaviate.client.v1.filters.WhereFilter;
Expand All @@ -15,6 +14,7 @@
import technology.semi.weaviate.client.v1.graphql.query.argument.SortOrder;
import technology.semi.weaviate.client.v1.graphql.query.fields.Field;
import technology.semi.weaviate.client.v1.graphql.query.fields.Fields;
import technology.semi.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder;

import java.io.BufferedReader;
import java.io.File;
Expand All @@ -23,7 +23,11 @@
import java.io.InputStreamReader;
import java.util.stream.Collectors;

public class GetBuilderTest extends TestCase {
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class GetBuilderTest {

@Test
public void testBuildSimpleGet() {
Expand Down Expand Up @@ -412,4 +416,61 @@ public void testBuildGetWithSort() {
assertEquals("{Get{Pizza(sort:[{path:[\"property1\"]}, {path:[\"property2\"] order:desc}]){name}}}", query2);
assertEquals("{Get{Pizza(sort:[{path:[\"property1\"]}, {path:[\"property2\"] order:desc}, {path:[\"property3\"] order:asc}]){name}}}", query3);
}

@Test
public void shouldBuildGetWithGenerativeSearchAndMultipleFieldsIncludingAdditional() {
// given
Fields fields = Fields.builder().fields(new Field[]{
Field.builder().name("name").build(),
Field.builder().name("description").build(),
Field.builder().name("_additional").fields(new Field[]{
Field.builder().name("id").build()
}).build()
}).build();

// when
String query = GetBuilder.builder()
.className("Pizza")
.fields(fields)
.withGenerativeSearch(
GenerativeSearchBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build()
)
.build().buildQuery();

// then
assertThat(query).isEqualTo("{Get{Pizza{name description _additional{id generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} " +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"})" +
"{singleResult groupedResult error}}}}}");
}

@Test
public void shouldBuildGetWithGenerativeSearchAndMultipleFields() {
// given
Fields fields = Fields.builder().fields(new Field[]{
Field.builder().name("name").build(),
Field.builder().name("description").build()
}).build();

// when
String query = GetBuilder.builder()
.className("Pizza")
.fields(fields)
.withGenerativeSearch(
GenerativeSearchBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build()
)
.build().buildQuery();

// then
assertThat(query).isEqualTo("{Get{Pizza{name description _additional{generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} " +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"})" +
"{singleResult groupedResult error}}}}}");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class GenerativeSearchBuilderTest {

@Test
public void shouldBuildEmptyField() {
GenerativeSearchBuilder generativeSearchBuilder = GenerativeSearchBuilder.builder()
.build();

Field generate = generativeSearchBuilder.build();

assertThat(generate.getName()).isBlank();
assertThat(generate.getFields()).isNull();
}

@Test
public void shouldBuildSingleResultPromptField() {
GenerativeSearchBuilder generativeSearchBuilder = GenerativeSearchBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.build();

Field generate = generativeSearchBuilder.build();

assertThat(generate.getName()).isEqualTo("generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"}" +
")");
assertThat(generate.getFields()).extracting(Field::getName)
.containsExactly("singleResult", "error");
}

@Test
public void shouldBuildGroupedResultTaskField() {
GenerativeSearchBuilder generativeSearchBuilder = GenerativeSearchBuilder.builder()
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build();

Field generate = generativeSearchBuilder.build();

assertThat(generate.getName()).isEqualTo("generate(" +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"}" +
")");
assertThat(generate.getFields()).extracting(Field::getName)
.containsExactly("groupedResult", "error");
}

@Test
public void shouldBuildBothSingleResultPromptAndGroupedResultTaskField() {
GenerativeSearchBuilder generativeSearchBuilder = GenerativeSearchBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build();

Field generate = generativeSearchBuilder.build();

assertThat(generate.getName()).isEqualTo("generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} " +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"}" +
")");
assertThat(generate.getFields()).extracting(Field::getName)
.containsExactly("singleResult", "groupedResult", "error");
}
}
Loading