Skip to content

Commit

Permalink
Merge pull request #187 from weaviate/generative_search_support
Browse files Browse the repository at this point in the history
Adds generative OpenAI support
  • Loading branch information
antas-marcin committed Mar 13, 2023
2 parents dc3d36a + 3280b5f commit f0fffb2
Show file tree
Hide file tree
Showing 13 changed files with 406 additions and 21 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
WCS_DUMMY_CI_PW: ${{ secrets.WCS_DUMMY_CI_PW }}
OKTA_CLIENT_SECRET: ${{ secrets.OKTA_CLIENT_SECRET }}
AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }}
OPENAI_APIKEY: ${{ secrets.OPENAI_APIKEY }}
run: |
docker-compose -f src/test/resources/docker-compose-test.yaml pull
mvn clean test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import technology.semi.weaviate.client.v1.graphql.query.builder.GetBuilder;
import technology.semi.weaviate.client.v1.graphql.query.fields.Field;
import technology.semi.weaviate.client.v1.graphql.query.fields.Fields;
import technology.semi.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder;

public class Get extends BaseClient<GraphQLResponse> implements ClientResult<GraphQLResponse> {
private final GetBuilder.GetBuilderBuilder getBuilder;
Expand Down Expand Up @@ -106,6 +107,11 @@ public Get withSort(SortArgument... sort) {
return this;
}

/**
 * Attaches a generative search request to this Get query.
 *
 * @param generativeSearch generative search settings, handed to the underlying query builder as-is
 * @return this instance, so further {@code with...} calls can be chained
 */
public Get withGenerativeSearch(GenerativeSearchBuilder generativeSearch) {
  getBuilder.withGenerativeSearch(generativeSearch);
  return this;
}

@Override
public Result<GraphQLResponse> run() {
String getQuery = this.getBuilder.build().buildQuery();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,18 @@
import technology.semi.weaviate.client.v1.graphql.query.argument.NearTextArgument;
import technology.semi.weaviate.client.v1.graphql.query.argument.NearVectorArgument;
import technology.semi.weaviate.client.v1.graphql.query.argument.SortArguments;
import technology.semi.weaviate.client.v1.graphql.query.fields.Field;
import technology.semi.weaviate.client.v1.graphql.query.fields.Fields;
import technology.semi.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

@Getter
@Builder
Expand All @@ -43,11 +51,12 @@ public class GetBuilder implements Query {
NearVectorArgument withNearVectorFilter;
GroupArgument withGroupArgument;
SortArguments withSortArguments;
GenerativeSearchBuilder withGenerativeSearch;

// Reports whether at least one of the listed filter, search, paging or sort arguments is non-null
// (ObjectUtils.anyNotNull over the fields below). withGenerativeSearch is not part of the list.
// NOTE(review): this span contains two return statements over the same argument list; it looks
// like the flattened diff kept both the removed and the re-wrapped version - confirm against the
// real file, where only one of them should exist.
private boolean includesFilterClause() {
return ObjectUtils.anyNotNull(withWhereFilter, withNearTextFilter, withNearObjectFilter,
withNearVectorFilter, withNearImageFilter, withGroupArgument, withAskArgument,withBm25Filter, withHybridFilter,
limit, offset, withSortArguments);
return ObjectUtils.anyNotNull(withWhereFilter, withNearTextFilter, withNearObjectFilter, withNearVectorFilter,
withNearImageFilter, withGroupArgument, withAskArgument, withBm25Filter, withHybridFilter, limit, offset,
withSortArguments);
}

private String createFilterClause() {
Expand Down Expand Up @@ -98,7 +107,54 @@ private String createFilterClause() {
}

/**
 * Renders the selection part of the query from {@code fields} and {@code withGenerativeSearch}.
 *
 * <ul>
 *   <li>both unset: returns an empty string;</li>
 *   <li>only {@code fields} set: returns {@code fields.build()};</li>
 *   <li>only {@code withGenerativeSearch} set: returns a single {@code _additional} field that
 *       wraps the generate field;</li>
 *   <li>both set: the generate field is appended to the first top-level {@code _additional}
 *       field, or a new {@code _additional} field is added when there is none.</li>
 * </ul>
 *
 * In the merged result all fields not named {@code _additional} come first, followed by the
 * {@code _additional} field(s).
 */
private String createFields() {
// NOTE(review): the next line looks like the pre-change one-line implementation kept by the
// flattened diff; with it in place everything below would be unreachable - confirm.
return fields != null ? fields.build() : "";
if (ObjectUtils.allNull(fields, withGenerativeSearch)) {
return "";
}

if (withGenerativeSearch == null) {
return fields.build();
}

// generate field produced by the generative search settings, plus a standalone _additional
// wrapper used when there is no existing _additional to merge into
Field generate = withGenerativeSearch.build();
Field generateAdditional = Field.builder()
.name("_additional")
.fields(new Field[]{generate})
.build();

if (fields == null) {
return generateAdditional.build();
}

// check if _additional field exists. If missing just add new _additional with generate,
// if exists merge generate into present one
// (key true = fields named "_additional", key false = every other field)
Map<Boolean, List<Field>> grouped = Arrays.stream(fields.getFields())
.collect(Collectors.groupingBy(f -> "_additional".equals(f.getName())));

List<Field> additionals = grouped.getOrDefault(true, new ArrayList<>());
if (additionals.isEmpty()) {
additionals.add(generateAdditional);
} else {
// Only the first _additional entry is extended; any further ones are passed through untouched.
// NOTE(review): Arrays.stream throws NullPointerException for a null array, so an existing
// _additional whose getFields() returns null would fail here - confirm whether that can occur.
Field[] mergedInternalFields = Stream.concat(
Arrays.stream(additionals.get(0).getFields()),
Stream.of(generate)
).toArray(Field[]::new);

// the replacement is rebuilt from name and merged sub-fields only
additionals.set(0, Field.builder()
.name("_additional")
.fields(mergedInternalFields)
.build()
);
}

// regular fields first, _additional last
Field[] allFields = Stream.concat(
grouped.getOrDefault(false, new ArrayList<>()).stream(),
additionals.stream()
).toArray(Field[]::new);

return Fields.builder()
.fields(allFields)
.build()
.build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.Arrays;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
Expand All @@ -12,6 +13,7 @@
@Getter
@Builder
@ToString
@EqualsAndHashCode
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class Field implements Argument {
String name;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import java.util.Arrays;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.ObjectUtils;
import technology.semi.weaviate.client.v1.graphql.query.argument.Argument;

import java.util.Arrays;
import java.util.stream.Collectors;

/**
 * Holds an array of {@link Field} entries and renders them as one space-separated string.
 */
@Getter
@Builder
@ToString
@EqualsAndHashCode
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class Fields implements Argument {
// entries to render, in output order
Field[] fields;

/**
 * Builds every field and joins the results with a single space.
 *
 * @return the joined string, or an empty string when {@code fields} is null or has no elements
 */
// NOTE(review): the body below interleaves two variants (a StringUtils.joinWith version and an
// ObjectUtils.isEmpty / Collectors.joining version); this looks like removed and added diff lines
// shown together - confirm against the real file.
@Override
public String build() {
if (this.fields != null && this.fields.length > 0) {
return StringUtils.joinWith(" ", Arrays.stream(this.fields).map(Field::build).toArray());
if (ObjectUtils.isEmpty(fields)) {
return "";
}
return "";
return Arrays.stream(fields)
.map(Field::build)
.collect(Collectors.joining(" "));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.StringUtils;

import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Settings for a {@code generate(...)} field: an optional per-object prompt and an optional task
 * applied to the grouped results. {@link #build()} turns them into a {@link Field}.
 */
@Getter
@Builder
@ToString
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GenerativeSearchBuilder {

  // prompt rendered as singleResult:{prompt:"""..."""}; ignored when blank
  String singleResultPrompt;
  // task rendered as groupedResult:{task:"""..."""}; ignored when blank
  String groupedResultTask;


  /**
   * Builds the {@code generate(...)} field.
   *
   * <p>The field name carries the arguments, e.g.
   * {@code generate(singleResult:{prompt:"""..."""} groupedResult:{task:"""..."""})}. Its
   * sub-fields are {@code singleResult} and/or {@code groupedResult} (matching the arguments that
   * were set, in that order) followed by {@code error}.
   *
   * @return the generate field, or a field built without name and sub-fields when both the
   *     prompt and the task are blank
   */
  public Field build() {
    Set<String> nameParts = new LinkedHashSet<>();
    Set<String> fieldNames = new LinkedHashSet<>();

    if (StringUtils.isNotBlank(singleResultPrompt)) {
      nameParts.add(String.format("singleResult:{prompt:%s}", blockString(singleResultPrompt)));
      fieldNames.add("singleResult");
    }
    if (StringUtils.isNotBlank(groupedResultTask)) {
      nameParts.add(String.format("groupedResult:{task:%s}", blockString(groupedResultTask)));
      fieldNames.add("groupedResult");
    }

    if (nameParts.isEmpty()) {
      return Field.builder().build();
    }

    fieldNames.add("error");
    String name = String.format("generate(%s)", StringUtils.join(nameParts, " "));
    Field[] fields = fieldNames.stream()
      .map(n -> Field.builder().name(n).build())
      .toArray(Field[]::new);

    return Field.builder().name(name).fields(fields).build();
  }

  /**
   * Wraps text in GraphQL block-string delimiters ({@code """}).
   *
   * <p>A triple quote inside the text would otherwise end the block string early and let the
   * remainder be parsed as query syntax, so each embedded {@code """} is written as
   * {@code \"""}, the block-string escape defined by the GraphQL specification.
   */
  private static String blockString(String text) {
    return "\"\"\"" + text.replace("\"\"\"", "\\\"\"\"") + "\"\"\"";
  }
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package technology.semi.weaviate.client.v1.graphql.query.builder;

import junit.framework.TestCase;
import org.junit.Test;
import technology.semi.weaviate.client.v1.filters.Operator;
import technology.semi.weaviate.client.v1.filters.WhereFilter;
Expand All @@ -15,6 +14,7 @@
import technology.semi.weaviate.client.v1.graphql.query.argument.SortOrder;
import technology.semi.weaviate.client.v1.graphql.query.fields.Field;
import technology.semi.weaviate.client.v1.graphql.query.fields.Fields;
import technology.semi.weaviate.client.v1.graphql.query.fields.GenerativeSearchBuilder;

import java.io.BufferedReader;
import java.io.File;
Expand All @@ -23,7 +23,11 @@
import java.io.InputStreamReader;
import java.util.stream.Collectors;

public class GetBuilderTest extends TestCase {
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class GetBuilderTest {

@Test
public void testBuildSimpleGet() {
Expand Down Expand Up @@ -412,4 +416,61 @@ public void testBuildGetWithSort() {
assertEquals("{Get{Pizza(sort:[{path:[\"property1\"]}, {path:[\"property2\"] order:desc}]){name}}}", query2);
assertEquals("{Get{Pizza(sort:[{path:[\"property1\"]}, {path:[\"property2\"] order:desc}, {path:[\"property3\"] order:asc}]){name}}}", query3);
}

// Generative search combined with fields that already contain an _additional block:
// the generate field is expected to be appended inside that existing block, after "id".
@Test
public void shouldBuildGetWithGenerativeSearchAndMultipleFieldsIncludingAdditional() {
  // given
  Field additionalWithId = Field.builder()
    .name("_additional")
    .fields(new Field[]{Field.builder().name("id").build()})
    .build();
  Field[] requested = {
    Field.builder().name("name").build(),
    Field.builder().name("description").build(),
    additionalWithId
  };
  Fields fields = Fields.builder().fields(requested).build();
  GenerativeSearchBuilder generativeSearch = GenerativeSearchBuilder.builder()
    .singleResultPrompt("What is the meaning of life?")
    .groupedResultTask("Explain why these magazines or newspapers are about finance")
    .build();

  // when
  GetBuilder getBuilder = GetBuilder.builder()
    .className("Pizza")
    .fields(fields)
    .withGenerativeSearch(generativeSearch)
    .build();
  String query = getBuilder.buildQuery();

  // then
  String expected = "{Get{Pizza{name description _additional{id generate("
    + "singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} "
    + "groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"})"
    + "{singleResult groupedResult error}}}}}";
  assertThat(query).isEqualTo(expected);
}

// Generative search combined with plain fields only: a new _additional block holding the
// generate field is expected after the requested properties.
@Test
public void shouldBuildGetWithGenerativeSearchAndMultipleFields() {
  // given
  Field[] requested = {
    Field.builder().name("name").build(),
    Field.builder().name("description").build()
  };
  Fields fields = Fields.builder().fields(requested).build();
  GenerativeSearchBuilder generativeSearch = GenerativeSearchBuilder.builder()
    .singleResultPrompt("What is the meaning of life?")
    .groupedResultTask("Explain why these magazines or newspapers are about finance")
    .build();

  // when
  GetBuilder getBuilder = GetBuilder.builder()
    .className("Pizza")
    .fields(fields)
    .withGenerativeSearch(generativeSearch)
    .build();
  String query = getBuilder.buildQuery();

  // then
  String expected = "{Get{Pizza{name description _additional{generate("
    + "singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} "
    + "groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"})"
    + "{singleResult groupedResult error}}}}}";
  assertThat(query).isEqualTo(expected);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Checks the {@link Field} produced by {@link GenerativeSearchBuilder#build()} for each
 * combination of prompt and task: neither, prompt only, task only, and both.
 */
public class GenerativeSearchBuilderTest {

  @Test
  public void shouldBuildEmptyField() {
    // nothing configured
    Field field = GenerativeSearchBuilder.builder().build().build();

    assertThat(field.getName()).isBlank();
    assertThat(field.getFields()).isNull();
  }

  @Test
  public void shouldBuildSingleResultPromptField() {
    GenerativeSearchBuilder settings = GenerativeSearchBuilder.builder()
      .singleResultPrompt("What is the meaning of life?")
      .build();

    Field field = settings.build();

    String expectedName =
      "generate(singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"})";
    assertThat(field.getName()).isEqualTo(expectedName);
    assertThat(field.getFields())
      .extracting(Field::getName)
      .containsExactly("singleResult", "error");
  }

  @Test
  public void shouldBuildGroupedResultTaskField() {
    GenerativeSearchBuilder settings = GenerativeSearchBuilder.builder()
      .groupedResultTask("Explain why these magazines or newspapers are about finance")
      .build();

    Field field = settings.build();

    String expectedName = "generate("
      + "groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"}"
      + ")";
    assertThat(field.getName()).isEqualTo(expectedName);
    assertThat(field.getFields())
      .extracting(Field::getName)
      .containsExactly("groupedResult", "error");
  }

  @Test
  public void shouldBuildBothSingleResultPromptAndGroupedResultTaskField() {
    GenerativeSearchBuilder settings = GenerativeSearchBuilder.builder()
      .singleResultPrompt("What is the meaning of life?")
      .groupedResultTask("Explain why these magazines or newspapers are about finance")
      .build();

    Field field = settings.build();

    String expectedName = "generate("
      + "singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} "
      + "groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"}"
      + ")";
    assertThat(field.getName()).isEqualTo(expectedName);
    assertThat(field.getFields())
      .extracting(Field::getName)
      .containsExactly("singleResult", "groupedResult", "error");
  }
}
Loading

0 comments on commit f0fffb2

Please sign in to comment.