added docs
vga91 committed May 14, 2024
1 parent 4b15179 commit a65d3e3
Showing 5 changed files with 177 additions and 338 deletions.
163 changes: 152 additions & 11 deletions docs/asciidoc/modules/ROOT/pages/ml/openai.adoc
@@ -511,11 +511,12 @@ overall, this graph database schema provides a simple yet powerful representatio
| value | the description of the dataset
|===

== Query with natural language
== Query with the Retrieval-augmented generation (RAG) technique

This procedure `apoc.ml.rag` takes a natural language question and transforms it into a number of requested cypher queries.
The `apoc.ml.rag` procedure takes a list of paths or a vector index name, the relevant attributes, and a natural language question,
and builds a prompt that implements the Retrieval-augmented generation (RAG) technique.

TODO:
See https://aws.amazon.com/what-is/retrieval-augmented-generation/[here] for more info about the RAG process.

It uses the `chat/completions` API which is https://platform.openai.com/docs/api-reference/chat/create[documented here^].
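
As a rough illustration of the augment and generate steps, the retrieved facts can be joined into a context block and combined with the question before the chat call. The sketch below is plain Java under stated assumptions: the class `RagPromptSketch`, the template text and the fallback answer are hypothetical, since the real `RAG_BASE_PROMPT` used by the procedure is not shown in this commit.

```java
import java.util.List;

/** Hypothetical sketch of the augment and generate steps; not the APOC source. */
final class RagPromptSketch {

    // Assumed template: the actual RAG_BASE_PROMPT text is not part of this commit.
    static final String TEMPLATE = """
            Use only the context below to answer. If the answer is not there, reply "%s".
            Context:
            %s""";

    // Assumed fallback answer, for illustration only.
    static final String UNKNOWN_ANSWER = "Sorry, I don't know";

    /** Joins the retrieved facts into one context block and appends the question. */
    static String buildPrompt(List<String> retrievedFacts, String question) {
        String context = String.join("\n", retrievedFacts);
        return TEMPLATE.formatted(UNKNOWN_ANSWER, context) + "\nQuestion:" + question;
    }

    public static void main(String[] args) {
        System.out.println(buildPrompt(
                List.of("Athlete {name=Amos Mosaner, country=Italy}", "HAS_MEDAL {medal=Gold}"),
                "Who won the gold medal?"));
    }
}
```

The resulting string is what would be sent to the chat model, so the model answers from the supplied context rather than from its training data alone.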

@@ -525,30 +526,170 @@ It uses the `chat/completions` API which is https://platform.openai.com/docs/api
[%autowidth, opts=header]
|===
| name | description | mandatory
| cypher | The question in the natural language | yes
| paths | the list of paths used to retrieve data and augment the prompt; it can also be a matching query string or a vector index name | yes
| attributes | the names of the relevant properties to retrieve and include in the prompt | yes
| question | the user question | yes
| conf | An optional configuration map, please check the next section | no
|===


.Configuration map
[%autowidth, opts=header]
|===
| name | description | mandatory
| retries | The number of retries in case of API call failures | no, default `3`
| getLabelTypes | whether to add the label / relationship type names to the information that augments the prompt | no, default `true`
| embeddings | searches for similar embeddings stored in a node vector index (with `embeddings: "NODE"`) or in a relationship vector index (with `embeddings: "REL"`) | no, default `"FALSE"`
| topK | the number of neighbors to find for each node (with `embeddings: "NODE"`) or relationship (with `embeddings: "REL"`) | no, default `40`
| apiKey | OpenAI API key | in case `apoc.openai.key` is not defined
| model | The OpenAI model | no, default `gpt-3.5-turbo`
| sample | The number of nodes to skip, e.g. a sample of 1000 will read every 1000th node. It's used as a parameter to `apoc.meta.data` procedure that computes the schema | no, default is a random number
|===
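
The defaults listed in the configuration table can be summarized in a small reader class. This is a minimal sketch, not the `RagConfig` class from the codebase: the name `RagConfSketch` and the explicit check on the `embeddings` value are assumptions made for illustration.

```java
import java.util.List;
import java.util.Map;

/** Hypothetical reader for the configuration map; not the APOC RagConfig class. */
final class RagConfSketch {
    final int retries;
    final boolean getLabelTypes;
    final String embeddings;
    final int topK;
    final String model;

    RagConfSketch(Map<String, Object> conf) {
        // Defaults taken from the configuration table.
        this.retries = ((Number) conf.getOrDefault("retries", 3)).intValue();
        this.getLabelTypes = Boolean.parseBoolean(String.valueOf(conf.getOrDefault("getLabelTypes", true)));
        this.embeddings = String.valueOf(conf.getOrDefault("embeddings", "FALSE"));
        this.topK = ((Number) conf.getOrDefault("topK", 40)).intValue();
        this.model = String.valueOf(conf.getOrDefault("model", "gpt-3.5-turbo"));
        // Assumed validation: only the three documented values are accepted.
        if (!List.of("FALSE", "NODE", "REL").contains(embeddings)) {
            throw new IllegalArgumentException("embeddings must be FALSE, NODE or REL, got: " + embeddings);
        }
    }
}
```

Reading numeric values through `Number` keeps the sketch tolerant of both `Integer` and `Long` inputs.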


With the `apoc.ml.rag` procedure we can reduce AI hallucinations (i.e. false or misleading responses)
by passing relevant, up-to-date information to the procedure via the first parameter.

For example, the following call (with the `gpt-3.5-turbo` model, last updated in January 2022)
produces a hallucination:

.Query call
[source,cypher]
----
CALL apoc.ml.query("What movies did Tom Hanks play in?") yield value, query
RETURN *
CALL apoc.ml.openai.chat([
{role:"user", content: "Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?"}
], $apiKey)
----

.Example response
[opts="header"]
|===
| value
| TODO
|===
| The gold medal in curling at the 2022 Winter Olympics was won by the Swedish men's team and the Russian women's team.
|===

We can therefore use the RAG technique to ground the answer in real data.
For example, given the following dataset (with data taken from https://en.wikipedia.org/wiki/Curling_at_the_2022_Winter_Olympics[this Wikipedia page]):

.Wikipedia dataset
[source,cypher]
----
CREATE (mixed2022:Discipline {title:"Mixed doubles's curling", year: 2022})
WITH mixed2022
CREATE (:Athlete {name: 'Stefania Constantini', country: 'Italy', irrelevant: 'asdasd'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'asdasd'}]->(mixed2022)
CREATE (:Athlete {name: 'Amos Mosaner', country: 'Italy', irrelevant: 'qweqwe'})-[:HAS_MEDAL {medal: 'Gold', irrelevant2: 'rwerew'}]->(mixed2022)
CREATE (:Athlete {name: 'Kristin Skaslien', country: 'Norway', irrelevant: 'dfgdfg'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'gdfg'}]->(mixed2022)
CREATE (:Athlete {name: 'Magnus Nedregotten', country: 'Norway', irrelevant: 'xcvxcv'})-[:HAS_MEDAL {medal: 'Silver', irrelevant2: 'asdasd'}]->(mixed2022)
CREATE (:Athlete {name: 'Almida de Val', country: 'Sweden', irrelevant: 'rtyrty'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'bfbfb'}]->(mixed2022)
CREATE (:Athlete {name: 'Oskar Eriksson', country: 'Sweden', irrelevant: 'qwresdc'})-[:HAS_MEDAL {medal: 'Bronze', irrelevant2: 'juju'}]->(mixed2022)
----

we can execute:

.Query call
[source,cypher]
----
MATCH path=(:Athlete)-[:HAS_MEDAL]->(:Discipline)
WITH collect(path) AS paths
CALL apoc.ml.rag(paths,
["name", "country", "medal", "title", "year"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey}
) YIELD value
RETURN value
----

.Example response
[opts="header"]
|===
| value
| The gold medal in curling at the 2022 Winter Olympics was won by Stefania Constantini and Amos Mosaner from Italy.
|===
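
The role of the attribute list can be pictured as a filter over each entity's properties: only the requested keys reach the prompt, so properties such as `irrelevant` in the dataset above are left out. The following is a minimal sketch under stated assumptions. `AttributeFilterSketch` and its methods are hypothetical and work on plain maps, whereas the procedure itself operates on Neo4j nodes and relationships.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical sketch of attribute filtering on plain maps; not the APOC source. */
final class AttributeFilterSketch {

    /** Keeps only the requested attributes, in the order they were requested. */
    static Map<String, Object> relevantProps(Map<String, Object> props, List<String> attributes) {
        Map<String, Object> kept = new LinkedHashMap<>();
        for (String attr : attributes) {
            if (props.containsKey(attr)) {
                kept.put(attr, props.get(attr));
            }
        }
        return kept;
    }

    /** Renders one entity as a context line, optionally prefixed by its label or relationship type. */
    static String contextLine(String labelOrType, Map<String, Object> props,
                              List<String> attributes, boolean getLabelTypes) {
        String body = relevantProps(props, attributes).toString();
        return getLabelTypes ? labelOrType + " " + body : body;
    }
}
```

With `getLabelTypes` enabled, each context line also carries the label or relationship type, which gives the model a hint about what kind of entity the properties belong to.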

or:

.Query call
[source,cypher]
----
MATCH path=(:Athlete)-[:HAS_MEDAL]->(:Discipline)
WITH collect(path) AS paths
CALL apoc.ml.rag(paths,
["name", "country", "medal", "title", "year"],
"Which athletes won the silver medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey}
) YIELD value
RETURN value
----

.Example response
[opts="header"]
|===
| value
| The silver medal in curling at the 2022 Winter Olympics was won by Kristin Skaslien and Magnus Nedregotten from Norway.
|===

We can also pass a string query returning paths/relationships/nodes, for example:

[source,cypher]
----
CALL apoc.ml.rag("MATCH path=(:Athlete)-[:HAS_MEDAL]->(:Discipline) RETURN path",
["name", "country", "medal", "title", "year"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey}
) YIELD value
RETURN value
----

.Example response
[opts="header"]
|===
| value
| The gold medal in curling at the 2022 Winter Olympics was won by Stefania Constantini and Amos Mosaner from Italy.
|===

Alternatively, we can pass a vector index name as the first parameter, if we have stored useful information in nodes with embeddings.
For example, given this node vector index:

[source,cypher]
----
CREATE VECTOR INDEX `rag-embeddings`
FOR (n:RagEmbedding) ON (n.embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1536,
`vector.similarity_function`: 'cosine'
}}
----

and some `(:RagEmbedding)` nodes with a `text` property, we can execute:

[source,cypher]
----
CALL apoc.ml.rag("rag-embeddings",
["text"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey, embeddings: "NODE", topK: 20}
) YIELD value
RETURN value
----

or, with a relationship vector index:


[source,cypher]
----
CREATE VECTOR INDEX `rag-rel-embeddings`
FOR ()-[r:RAG_EMBEDDING]-() ON (r.embedding)
OPTIONS {indexConfig: {
`vector.dimensions`: 1536,
`vector.similarity_function`: 'cosine'
}}
----

and some `[:RAG_EMBEDDING]` relationships with a `text` property, we can execute:

[source,cypher]
----
CALL apoc.ml.rag("rag-rel-embeddings",
["text"],
"Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?",
{apiKey: $apiKey, embeddings: "REL", topK: 20}
) YIELD value
RETURN value
----
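
To give an intuition for what `topK` and the `cosine` similarity function mean in the vector index examples, the sketch below ranks a few stored vectors against a query vector and keeps the best `k`. It is only a conceptual illustration: in the procedure the search is delegated to the Neo4j vector index, and the class `TopKSketch`, its methods and the two-dimensional toy vectors are assumptions (the indexes above use 1536 dimensions).

```java
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/** Hypothetical brute-force illustration of a top-k cosine search; not how the index works internally. */
final class TopKSketch {

    /** Cosine similarity between two vectors of equal length. */
    static double cosine(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        return dot / (Math.sqrt(normA) * Math.sqrt(normB));
    }

    /** Returns the keys of the k stored embeddings most similar to the query, best first. */
    static List<String> topK(Map<String, double[]> stored, double[] query, int k) {
        return stored.entrySet().stream()
                .sorted(Comparator.comparingDouble(
                        (Map.Entry<String, double[]> e) -> -cosine(e.getValue(), query)))
                .limit(k)
                .map(Map.Entry::getKey)
                .toList();
    }
}
```

A larger `topK` passes more candidate texts into the prompt, which can improve recall at the cost of a longer context.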
31 changes: 18 additions & 13 deletions extended/src/main/java/apoc/ml/Prompt.java
@@ -60,7 +60,7 @@ public Stream<StringResult> rag(@Name("paths") Object paths,

RagConfig config = new RagConfig(conf);

String[] objects = attributes.toArray(String[]::new);
String[] arrayAttrs = attributes.toArray(String[]::new);

StringBuilder context = new StringBuilder();

@@ -69,42 +69,47 @@ public Stream<StringResult> rag(@Name("paths") Object paths,

for (var listItem : pathList) {
// -- Augment
extracted2(config, objects, context, listItem);
augment(config, arrayAttrs, context, listItem);
}

} else if (paths instanceof String queryOrIndex) {
config.getEmbedding()
config.getEmbeddings()
.getQuery(queryOrIndex, question, tx, config)
// -- Augment
.forEachRemaining(i -> i
.forEachRemaining(row -> row
.values()
.forEach( v -> extracted2(config, objects, context, v) )
// -- Augment
.forEach( val -> augment(config, arrayAttrs, context, val) )
);
} else {
throw new RuntimeException("The first parameter must be a List or a String");
}

// - Generate
String prompt = RAG_BASE_PROMPT.formatted(UNKNOWN_ANSWER, context);

String question1 = "\nQuestion:" + question;
String result = prompt(question1, prompt, null, null, conf, List.of());

String result = prompt("\nQuestion:" + question,
prompt,
null,
null,
conf,
List.of()
);
return Stream.of(new StringResult(result));
}

private static void extracted2(RagConfig config, String[] objects, StringBuilder context, Object listItem) {
private void augment(RagConfig config, String[] objects, StringBuilder context, Object listItem) {
if (listItem instanceof Path p) {
for (Entity entity : p) {
extracted(config, objects, context, entity);
augmentEntity(config, objects, context, entity);
}
} else if (listItem instanceof Entity e) {
extracted(config, objects, context, e);
augmentEntity(config, objects, context, e);
} else {
throw new RuntimeException("The list `%s` must have node/type/path items".formatted(listItem));
}
}

private static void extracted(RagConfig config, String[] objects, StringBuilder context, Entity entity) {
private void augmentEntity(RagConfig config, String[] objects, StringBuilder context, Entity entity) {
Map<String, Object> props = entity.getProperties(objects);
if (config.isGetLabelTypes()) {
String labelsOrType = entity instanceof Node node
8 changes: 4 additions & 4 deletions extended/src/main/java/apoc/ml/RagConfig.java
@@ -14,7 +14,7 @@ public class RagConfig {
public static final String TOP_K_CONF = "topK";

private final boolean getLabelTypes;
private final EmbeddingQuery embedding;
private final EmbeddingQuery embeddings;
private final Integer topK;
private final String apiKey;
private final Map<String, Object> confMap;
@@ -27,7 +27,7 @@ public RagConfig(Map<String, Object> confMap) {
this.confMap = confMap;
this.getLabelTypes = Util.toBoolean(confMap.getOrDefault(GET_LABEL_TYPES_CONF, true));
String embeddingString = (String) confMap.getOrDefault(EMBEDDINGS_CONF, EmbeddingQuery.Type.FALSE.name());
this.embedding = EmbeddingQuery.Type.valueOf(embeddingString).get();
this.embeddings = EmbeddingQuery.Type.valueOf(embeddingString).get();
this.topK = Util.toInteger(confMap.getOrDefault(TOP_K_CONF, 40));
this.apiKey = (String) confMap.get(API_KEY_CONF);
}
@@ -36,8 +36,8 @@ public boolean isGetLabelTypes() {
return getLabelTypes;
}

public EmbeddingQuery getEmbedding() {
return embedding;
public EmbeddingQuery getEmbeddings() {
return embeddings;
}

public Integer getTopK() {
2 changes: 1 addition & 1 deletion extended/src/test/java/apoc/ml/PromptIT.java
@@ -174,7 +174,7 @@ public void testFromCypher() {

@Test
public void ragWithRelevantAttributesComparedToIrrelevantOneAndChatProcedure() {
String question = "Which athletes won the gold medal in curling at the 2022 Winter Olympics?";
String question = "Which athletes won the gold medal in mixed doubles's curling at the 2022 Winter Olympics?";

// -- test with hallucinations, wrong winner names
testCall(db, """