Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
vga91 committed May 10, 2024
1 parent ecbc228 commit fac0923
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions extended/src/main/java/apoc/ml/Prompt.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,64 @@
*/

/*
def num_tokens(text: str, model: str = GPT_MODEL) -> int:
"""Return the number of tokens in a string."""
encoding = tiktoken.encoding_for_model(model)
return len(encoding.encode(text))
def query_message(
query: str,
df: pd.DataFrame,
model: str,
token_budget: int
) -> str:
"""Return a message for GPT, with relevant source texts pulled from a dataframe."""
strings, relatednesses = strings_ranked_by_relatedness(query, df)
introduction = 'Use the below articles on the 2022 Winter Olympics to answer the subsequent question. If the answer cannot be found in the articles, write "I could not find an answer."'
question = f"\n\nQuestion: {query}"
message = introduction
for string in strings:
next_article = f'\n\nWikipedia article section:\n"""\n{string}\n"""'
if (
num_tokens(message + next_article + question, model=model)
> token_budget
):
break
else:
message += next_article
return message + question
def ask(
query: str,
df: pd.DataFrame = df,
model: str = GPT_MODEL,
token_budget: int = 4096 - 500,
print_message: bool = False,
) -> str:
"""Answers a query using GPT and a dataframe of relevant texts and embeddings."""
message = query_message(query, df, model=model, token_budget=token_budget)
if print_message:
print(message)
messages = [
{"role": "system", "content": "You answer questions about the 2022 Winter Olympics."},
{"role": "user", "content": message},
]
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=0
)
response_message = response.choices[0].message.content
return response_message
todo --> devo passare dei path che hanno delle proprietà interessanti..
test:
*/

@Extended
public class Prompt {

Expand All @@ -62,13 +120,30 @@ public class Prompt {
public ProcedureCallContext procedureCallContext;
@Context
public URLAccessChecker urlAccessChecker;


@Procedure(mode = Mode.READ)
@Description("Takes a query in cypher and in natural language and returns the results in natural language")
public Stream<StringResult> rag(@Name("cypher") String cypher,
@Name(value = "conf", defaultValue = "{}") Map<String, Object> conf) throws MalformedURLException, JsonProcessingException {
String schema = loadSchema(tx, conf);

String schemaExplanation = prompt("Please explain the graph database schema to me and relate it to well known concepts and domains.",
FROM_CYPHER_PROMPT, "This database schema ", schema, conf, List.of());
return Stream.of(new StringResult(schemaExplanation));
}


public static final String BACKTICKS = "```";
public static final String RAG_PROMPT = """
Use the below article on the 2022 Winter Olympics to answer the subsequent question. If the answer cannot be found, write "I don't know.
""";
public static final String EXPLAIN_SCHEMA_PROMPT = """
You are an expert in the Neo4j graph database and graph data modeling and have experience in a wide variety of business domains.
Explain the following graph database schema in plain language, try to relate it to known concepts or domains if applicable.
Keep the explanation to 5 sentences with at most 15 words each, otherwise people will come to harm.
""";


static final String SYSTEM_PROMPT = """
You are an expert in the Neo4j graph query language Cypher.
Expand Down

0 comments on commit fac0923

Please sign in to comment.