Skip to content

Commit

Permalink
Adds generative openai support
Browse files Browse the repository at this point in the history
  • Loading branch information
aliszka committed Mar 8, 2023
1 parent f3d246e commit b75ff2a
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 8 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
WCS_DUMMY_CI_PW: ${{ secrets.WCS_DUMMY_CI_PW }}
OKTA_CLIENT_SECRET: ${{ secrets.OKTA_CLIENT_SECRET }}
AZURE_CLIENT_SECRET: ${{ secrets.AZURE_CLIENT_SECRET }}
OPENAI_APIKEY: ${{ secrets.OPENAI_APIKEY }}
run: |
docker-compose -f src/test/resources/docker-compose-test.yaml pull
mvn clean test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.Arrays;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
Expand All @@ -12,6 +13,7 @@
@Getter
@Builder
@ToString
@EqualsAndHashCode
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class Field implements Argument {
String name;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import java.util.Arrays;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.EqualsAndHashCode;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.ObjectUtils;
import technology.semi.weaviate.client.v1.graphql.query.argument.Argument;

import java.util.Arrays;
import java.util.stream.Collectors;

@Getter
@Builder
@ToString
@EqualsAndHashCode
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class Fields implements Argument {
Field[] fields;

@Override
public String build() {
if (this.fields != null && this.fields.length > 0) {
return StringUtils.joinWith(" ", Arrays.stream(this.fields).map(Field::build).toArray());
if (ObjectUtils.isEmpty(fields)) {
return "";
}
return "";
return Arrays.stream(fields)
.map(Field::build)
.collect(Collectors.joining(" "));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import org.apache.commons.lang3.StringUtils;

import java.util.LinkedHashSet;
import java.util.Set;

@Getter
@Builder
@ToString
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
public class GenerateFieldBuilder {

String singleResultPrompt;
String groupedResultTask;


public Field build() {
Set<String> nameParts = new LinkedHashSet<>();
Set<String> fieldNames = new LinkedHashSet<>();

if (StringUtils.isNotBlank(singleResultPrompt)) {
nameParts.add(String.format("singleResult:{prompt:\"\"\"%s\"\"\"}", singleResultPrompt));
fieldNames.add("singleResult");
}
if (StringUtils.isNotBlank(groupedResultTask)) {
nameParts.add(String.format("groupedResult:{task:\"\"\"%s\"\"\"}", groupedResultTask));
fieldNames.add("groupedResult");
}

if (nameParts.isEmpty()) {
return Field.builder().build();
}

fieldNames.add("error");
String name = String.format("generate(%s)", StringUtils.join(nameParts, " "));
Field[] fields = fieldNames.stream()
.map(n -> Field.builder().name(n).build())
.toArray(Field[]::new);

return Field.builder().name(name).fields(fields).build();
}

public Field buildWrappedByAdditional() {
return Field.builder()
.name("_additional")
.fields(new Field[]{ build() })
.build();
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package technology.semi.weaviate.client.v1.graphql.query.builder;

import junit.framework.TestCase;
import org.junit.Test;
import technology.semi.weaviate.client.v1.filters.Operator;
import technology.semi.weaviate.client.v1.filters.WhereFilter;
Expand All @@ -15,6 +14,7 @@
import technology.semi.weaviate.client.v1.graphql.query.argument.SortOrder;
import technology.semi.weaviate.client.v1.graphql.query.fields.Field;
import technology.semi.weaviate.client.v1.graphql.query.fields.Fields;
import technology.semi.weaviate.client.v1.graphql.query.fields.GenerateFieldBuilder;

import java.io.BufferedReader;
import java.io.File;
Expand All @@ -23,7 +23,11 @@
import java.io.InputStreamReader;
import java.util.stream.Collectors;

public class GetBuilderTest extends TestCase {
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

public class GetBuilderTest {

@Test
public void testBuildSimpleGet() {
Expand Down Expand Up @@ -412,4 +416,58 @@ public void testBuildGetWithSort() {
assertEquals("{Get{Pizza(sort:[{path:[\"property1\"]}, {path:[\"property2\"] order:desc}]){name}}}", query2);
assertEquals("{Get{Pizza(sort:[{path:[\"property1\"]}, {path:[\"property2\"] order:desc}, {path:[\"property3\"] order:asc}]){name}}}", query3);
}

@Test
public void shouldBuildGetWithMultipleFieldsIncludingGenerate() {
// given
Field name = Field.builder().name("name").build();
Field description = Field.builder().name("description").build();
Field _additional = Field.builder()
.name("_additional")
.fields(new Field[]{
Field.builder()
.name("id")
.build(),
GenerateFieldBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build().build(),
})
.build();
Fields fields = Fields.builder().fields(
new Field[]{ name, description, _additional }
).build();

// when
String query = GetBuilder.builder().className("Pizza").fields(fields).build().buildQuery();

// then
assertThat(query).isEqualTo("{Get{Pizza{name description _additional{id generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} " +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"})" +
"{singleResult groupedResult error}}}}}");
}

@Test
public void shouldBuildGetWithMultipleFieldsIncludingGenerateWrapped() {
// given
Field name = Field.builder().name("name").build();
Field description = Field.builder().name("description").build();
Field generateAdditional = GenerateFieldBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build().buildWrappedByAdditional();
Fields fields = Fields.builder().fields(
new Field[]{ name, description, generateAdditional }
).build();

// when
String query = GetBuilder.builder().className("Pizza").fields(fields).build().buildQuery();

// then
assertThat(query).isEqualTo("{Get{Pizza{name description _additional{generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} " +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"})" +
"{singleResult groupedResult error}}}}}");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package technology.semi.weaviate.client.v1.graphql.query.fields;

import org.junit.Test;

import static org.assertj.core.api.Assertions.assertThat;

public class GenerateFieldBuilderTest {

@Test
public void shouldBuildEmptyField() {
GenerateFieldBuilder generateFieldBuilder = GenerateFieldBuilder.builder()
.build();

Field generateAdditional = generateFieldBuilder.buildWrappedByAdditional();

assertThat(generateAdditional.getName()).isEqualTo("_additional");
assertThat(generateAdditional.getFields()).hasSize(1);

Field generate = generateFieldBuilder.build();

assertThat(generate).isEqualTo(generateAdditional.getFields()[0]);
assertThat(generate.getName()).isBlank();
assertThat(generate.getFields()).isNull();
}

@Test
public void shouldBuildSingleResultPromptField() {
GenerateFieldBuilder generateFieldBuilder = GenerateFieldBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.build();

Field generateAdditional = generateFieldBuilder.buildWrappedByAdditional();

assertThat(generateAdditional.getName()).isEqualTo("_additional");
assertThat(generateAdditional.getFields()).hasSize(1);

Field generate = generateFieldBuilder.build();

assertThat(generate).isEqualTo(generateAdditional.getFields()[0]);
assertThat(generate.getName()).isEqualTo("generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"}" +
")");
assertThat(generate.getFields()).extracting(Field::getName)
.containsExactly("singleResult", "error");
}

@Test
public void shouldBuildGroupedResultTaskField() {
GenerateFieldBuilder generateFieldBuilder = GenerateFieldBuilder.builder()
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build();

Field generateAdditional = generateFieldBuilder.buildWrappedByAdditional();

assertThat(generateAdditional.getName()).isEqualTo("_additional");
assertThat(generateAdditional.getFields()).hasSize(1);

Field generate = generateFieldBuilder.build();

assertThat(generate).isEqualTo(generateAdditional.getFields()[0]);
assertThat(generate.getName()).isEqualTo("generate(" +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"}" +
")");
assertThat(generate.getFields()).extracting(Field::getName)
.containsExactly("groupedResult", "error");
}

@Test
public void shouldBuildBothSingleResultPromptAndGroupedResultTaskField() {
GenerateFieldBuilder generateFieldBuilder = GenerateFieldBuilder.builder()
.singleResultPrompt("What is the meaning of life?")
.groupedResultTask("Explain why these magazines or newspapers are about finance")
.build();

Field generateAdditional = generateFieldBuilder.buildWrappedByAdditional();

assertThat(generateAdditional.getName()).isEqualTo("_additional");
assertThat(generateAdditional.getFields()).hasSize(1);

Field generate = generateFieldBuilder.build();

assertThat(generate).isEqualTo(generateAdditional.getFields()[0]);
assertThat(generate.getName()).isEqualTo("generate(" +
"singleResult:{prompt:\"\"\"What is the meaning of life?\"\"\"} " +
"groupedResult:{task:\"\"\"Explain why these magazines or newspapers are about finance\"\"\"}" +
")");
assertThat(generate.getFields()).extracting(Field::getName)
.containsExactly("singleResult", "groupedResult", "error");
}
}
Loading

0 comments on commit b75ff2a

Please sign in to comment.