Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vanna trulens performance metrics #238

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
428 changes: 428 additions & 0 deletions examples/app_evaluation.py

Large diffs are not rendered by default.

25 changes: 25 additions & 0 deletions examples/evaluations.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
## Evaluation of Vanna using TruLens Eval

1. [Install Trulens package](https://www.trulens.org/trulens_eval/install/).
2. In `evaluation/app_evaluation.py`, Configure the metrics to use, the test data (example user questions).
3. Run the script, and open dashboard at http://localhost:8501 to view results as they're processed.


### Compare multiple versions of the Vanna app
![alt text](images/trulens-2.png)

### See performance per-run of `vanna.generate_sql()`
![alt text](images/trulens-1.png)

### Overview of test results across multiple metrics
![alt text](images/trulens-3.png)

### Examining the inputs & outputs of context relevance metrics

- **Agreement to truth**: the response SQL is compared to the ground truth SQL.
- **Groundedness**: How much overlap there is in meaning, between a context item and the response SQL
- **sql**: The final SQL, and a question-sql pair retrieved from the vector store. An average of all example SQL-question pair scores (this is using just n=1)
- **ddl**: The final SQL, and a DDL pair retrieved from the vector store.
- **document**: The final SQL, and a document retrieved from the vector store.

![alt text](images/trulens-4.png)
Binary file added examples/images/trulens-1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/trulens-2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/trulens-3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/images/trulens-4.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
252 changes: 186 additions & 66 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def generate_sql(self, question: str, **kwargs) -> str:
doc_list=doc_list,
**kwargs,
)
self.log(prompt)
llm_response = self.submit_prompt(prompt, **kwargs)
self.log(llm_response)
return self.extract_sql(llm_response)

def extract_sql(self, llm_response: str) -> str:
Expand Down Expand Up @@ -133,6 +135,68 @@ def remove_training_data(id: str, **kwargs) -> bool:
# ----------------- Use Any Language Model API ----------------- #

@abstractmethod
def system_message(self, message: str) -> any:
pass

@abstractmethod
def user_message(self, message: str) -> any:
pass

@abstractmethod
def assistant_message(self, message: str) -> any:
pass

def str_to_approx_token_count(self, string: str) -> int:
return len(string) / 4

def add_ddl_to_prompt(
self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000
) -> str:
if len(ddl_list) > 0:
initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for ddl in ddl_list:
if (
self.str_to_approx_token_count(initial_prompt)
+ self.str_to_approx_token_count(ddl)
< max_tokens
):
initial_prompt += f"{ddl}\n\n"

return initial_prompt

def add_documentation_to_prompt(
self, initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000
) -> str:
if len(documentation_list) > 0:
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for documentation in documentation_list:
if (
self.str_to_approx_token_count(initial_prompt)
+ self.str_to_approx_token_count(documentation)
< max_tokens
):
initial_prompt += f"{documentation}\n\n"

return initial_prompt

def add_sql_to_prompt(
self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000
) -> str:
if len(sql_list) > 0:
initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"

for question in sql_list:
if (
self.str_to_approx_token_count(initial_prompt)
+ self.str_to_approx_token_count(question["sql"])
< max_tokens
):
initial_prompt += f"{question['question']}\n{question['sql']}\n\n"

return initial_prompt

def get_sql_prompt(
self,
question: str,
Expand All @@ -141,32 +205,125 @@ def get_sql_prompt(
doc_list: list,
**kwargs,
):
pass
initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n"

initial_prompt = self.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=14000
)

initial_prompt = self.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)

message_log = [self.system_message(initial_prompt)]

for example in question_sql_list:
if example is None:
print("example is None")
else:
if example is not None and "question" in example and "sql" in example:
message_log.append(self.user_message(example["question"]))
message_log.append(self.assistant_message(example["sql"]))

message_log.append(self.user_message(question))

return message_log

@abstractmethod
def get_followup_questions_prompt(
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs,
):
pass
) -> list:
initial_prompt = f"The user initially asked the question: '{question}': \n\n"

initial_prompt = self.add_ddl_to_prompt(
initial_prompt, ddl_list, max_tokens=14000
)

initial_prompt = self.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)

initial_prompt = self.add_sql_to_prompt(
initial_prompt, question_sql_list, max_tokens=14000
)

message_log = [self.system_message(initial_prompt)]
message_log.append(
self.user_message(
"Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions."
)
)

return message_log

@abstractmethod
def submit_prompt(self, prompt, **kwargs) -> str:
pass

@abstractmethod
def generate_question(self, sql: str, **kwargs) -> str:
pass
response = self.submit_prompt(
[
self.system_message(
"The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
),
self.user_message(sql),
],
**kwargs,
)

return response

def _extract_python_code(self, markdown_string: str) -> str:
# Regex pattern to match Python code blocks
pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```"

# Find all matches in the markdown string
matches = re.findall(pattern, markdown_string, re.IGNORECASE)

# Extract the Python code from the matches
python_code = []
for match in matches:
python = match[0] if match[0] else match[1]
python_code.append(python.strip())

if len(python_code) == 0:
return markdown_string

return python_code[0]

def _sanitize_plotly_code(self, raw_plotly_code: str) -> str:
# Remove the fig.show() statement from the plotly code
plotly_code = raw_plotly_code.replace("fig.show()", "")

return plotly_code

@abstractmethod
def generate_plotly_code(
self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs
) -> str:
pass
if question is not None:
system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'"
else:
system_msg = "The following is a pandas DataFrame "

if sql is not None:
system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n"

system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}"

message_log = [
self.system_message(system_msg),
self.user_message(
"Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code."
),
]

plotly_code = self.submit_prompt(message_log, kwargs=kwargs)

return self._sanitize_plotly_code(self._extract_python_code(plotly_code))

# ----------------- Connect to Any Database to run the Generated SQL ----------------- #

Expand Down Expand Up @@ -469,6 +626,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:

self.run_sql_is_set = True
self.run_sql = run_sql_bigquery

def connect_to_duckdb(self, url: str, init_sql: str = None):
"""
Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql]
Expand Down Expand Up @@ -514,9 +672,10 @@ def run_sql_duckdb(sql: str):

self.run_sql = run_sql_duckdb
self.run_sql_is_set = True

def run_sql(sql: str, **kwargs) -> pd.DataFrame:
raise NotImplementedError(
"You need to connect_to_snowflake or other database first."
"You need to connect to a database first by running vn.connect_to_snowflake(), vn.connect_to_sqlite(), vn.connect_to_postgres(), vn.connect_to_bigquery(), or vn.connect_to_duckdb()"
)

def ask(
Expand All @@ -533,6 +692,24 @@ def ask(
],
None,
]:
"""
**Example:**
```python
vn.ask("What are the top 10 customers by sales?")
```

Ask Vanna.AI a question and get the SQL query that answers it.

Args:
question (str): The question to ask.
print_results (bool): Whether to print the results of the SQL query.
auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query.
visualize (bool): Whether to generate plotly code and display the plotly figure.

Returns:
Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]: The SQL query, the results of the SQL query, and the plotly figure.
"""

if question is None:
question = input("Enter a question: ")

Expand Down Expand Up @@ -927,60 +1104,3 @@ def get_plotly_figure(

return fig


class SplitStorage(VannaBase):
def __init__(self, config=None):
VannaBase.__init__(self, config=config)

def get_similar_question_sql(self, embedding: str, **kwargs) -> list:
question_sql_ids = self.get_similar_question_sql_ids(embedding, **kwargs)
question_sql_list = self.get_question_sql(question_sql_ids, **kwargs)
return question_sql_list

def get_related_ddl(self, embedding: str, **kwargs) -> list:
ddl_ids = self.get_related_ddl_ids(embedding, **kwargs)
ddl_list = self.get_ddl(ddl_ids, **kwargs)
return ddl_list

def get_related_documentation(self, embedding: str, **kwargs) -> list:
doc_ids = self.get_related_documentation_ids(embedding, **kwargs)
doc_list = self.get_documentation(doc_ids, **kwargs)
return doc_list

# ----------------- Use Any Vector Database to Store and Lookup Embedding Similarity ----------------- #
@abstractmethod
def store_question_sql_embedding(self, embedding: str, **kwargs) -> str:
pass

@abstractmethod
def store_ddl_embedding(self, embedding: str, **kwargs) -> str:
pass

@abstractmethod
def store_documentation_embedding(self, embedding: str, **kwargs) -> str:
pass

@abstractmethod
def get_similar_question_sql_ids(self, embedding: str, **kwargs) -> list:
pass

@abstractmethod
def get_related_ddl_ids(self, embedding: str, **kwargs) -> list:
pass

@abstractmethod
def get_related_documentation_ids(self, embedding: str, **kwargs) -> list:
pass

# ----------------- Use Database to Retrieve the Documents from ID Lists ----------------- #
@abstractmethod
def get_question_sql(self, question_sql_ids: list, **kwargs) -> list:
pass

@abstractmethod
def get_documentation(self, doc_ids: list, **kwargs) -> list:
pass

@abstractmethod
def get_ddl(self, ddl_ids: list, **kwargs) -> list:
pass
2 changes: 1 addition & 1 deletion src/vanna/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def get_training_data():
{
"type": "df",
"id": "training_data",
"df": df.tail(25).to_json(orient="records"),
"df": df.to_json(orient="records"),
}
)

Expand Down
Loading