diff --git a/examples/app_evaluation.py b/examples/app_evaluation.py new file mode 100644 index 00000000..a1a5017d --- /dev/null +++ b/examples/app_evaluation.py @@ -0,0 +1,428 @@ +import json +from datetime import date +from typing import Callable, Dict, Iterable, List, Optional, Set + +import numpy as np +import requests +import rich +import sqlparse +from pydantic import Field +from sklearn.metrics import hamming_loss, jaccard_score + +# pip install "trulens-eval==0.19.2" +# pip install "litellm=1.21.7" +from trulens_eval import Feedback, LiteLLM, Select, Tru +from trulens_eval.feedback import Groundedness, GroundTruthAgreement, prompts +from trulens_eval.tru_custom_app import TruCustomApp, instrument +from trulens_eval.utils.generated import re_0_10_rating + +# From Vanna commit: e995493bcd189f3052c99ea8295c789a6de1aeea +from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore +from vanna.ollama import Ollama + +# Initialise Tru (see more https://www.trulens.org/trulens_eval/install/) +tru = Tru() +tru.reset_database() +tru.run_dashboard() + + +class OllamaLocalDB(ChromaDB_VectorStore, Ollama): + """Locally served LLM, and vector database""" + + def __init__(self, config: dict): + Ollama.__init__(self, config=config) + ChromaDB_VectorStore.__init__(self, config=config) + self.config: dict + self.call_log = [] + + def submit_prompt(self, prompt, **kwargs) -> str: + url = "http://localhost:11434/api/generate" + + payload = { + "model": self.config["model"].split("/")[-1], # or "llama2", + "prompt": prompt[0]["content"] + prompt[1]["content"], + "stream": False, + } + + payload_json = json.dumps(payload) + headers = {"Content-Type": "application/json"} + + response = requests.post(url, data=payload_json, headers=headers) + json_response = response.json() + self.call_log.append((payload_json, json_response)) + + return json_response["response"] + + +# Make sure these methods are tracked in the dashboard, and can have metrics for their results +instrument.method(OllamaLocalDB, "get_related_documentation") +instrument.method(OllamaLocalDB, "get_similar_question_sql") +instrument.method(OllamaLocalDB, "get_related_ddl") +instrument.method(OllamaLocalDB, "generate_sql") + + +class StandAloneProvider(LiteLLM): + """Inherits from LiteLLM to access the Ollama models. + + Adds functionality to evaluate Vanna-specific performance. + """ + + ground_truth_prompts: List[dict] + possible_table_names: Optional[List[str]] = Field(default=None) + table_name_indicies: Optional[Dict[str, int]] = Field(default=None) + dbpath: str + + def __init__(self, *args, **kwargs): + # TODO: why was self_kwargs required here independently of kwargs? + self_kwargs = dict() + self_kwargs.update(**kwargs) + # All database table names that the app has access to. + if self_kwargs["possible_table_names"] is not None: + self_kwargs["table_name_indicies"] = { + table_name: i + for i, table_name in enumerate(self_kwargs["possible_table_names"]) + } + + super().__init__(**self_kwargs) # need to include pydantic.BaseModel.__init__ + + def jacard_matching_tables(self, user_query, retrieved_tables): + """Measures the similarity between the predicted set of labels and the true set of labels. + It is calculated as the size of the intersection divided by the size of the union of the two label sets. + """ + return self.matching_tables(user_query, retrieved_tables, metric="jacard") + + def exact_match_matching_tables(self, user_query, retrieved_tables): + """Instances where the predicted labels exactly match the true labels. 
This must be performed on a per-instance
+        basis (with all predicted tables and all actual tables).
+
+        Extracts the table names from the retrieved node dictionaries.
+        """
+        tables = []
+        for retrieval in retrieved_tables:
+            tables.append(retrieval["node"]["metadata"]["name"])
+        return self.matching_tables(
+            user_query, retrieved_tables=tables, metric="exact_match"
+        )
+
+    def matching_tables(
+        self, user_query: str, retrieved_tables: Iterable[str], metric="hamming_loss"
+    ) -> float:
+        """Multi-label classification score for a single instance.
+
+        Available `metric` values: "hamming_loss", "jacard"/"jacard_similarity";
+        any other value falls back to an exact match of the binary label vectors.
+        """
+        if self.table_name_indicies is None:
+            raise ValueError("possible_table_names must be set")
+
+        # Get the first (and only) expected set of tables that matches the ground truth data
+        actual_tables = [
+            data["tables"]
+            for data in self.ground_truth_prompts
+            if data["query"] == user_query
+        ][0]
+
+        # Binary vectors to represent the tables:
+        # create a binary-valued vector from the possible table name indices
+        y_pred = [int(t in retrieved_tables) for t in self.table_name_indicies]
+        y_true = [int(t in actual_tables) for t in self.table_name_indicies]
+
+        if metric == "hamming_loss":
+            # Penalises false negatives (a table should have occurred but didn't)
+            # and false positives (a table occurred but shouldn't have) equally.
+            score = float(1 - hamming_loss(y_pred=y_pred, y_true=y_true))
+            return score
+
+        if metric in ("jacard", "jacard_similarity"):
+            # Penalises missing tables that should have been there, but not irrelevant extra tables.
+            # The binary average assumes that this score is aggregated with a sum.
+            score = float(jaccard_score(y_pred=y_pred, y_true=y_true, average="binary"))
+            return score
+
+        # Exact match: the predicted and true binary label vectors must be identical.
+        return float(y_pred == y_true)
+
+    def table_match_factory(self, metric="hamming_loss") -> Callable:
+        """Factory for a function that scores the match between the tables found in the SQL query
+        and the ground-truth expected tables for the query."""
+
+        def func(user_query: str, sql_result: str) -> float:
+            assert type(sql_result) == str, "type(sql_result) == str"
+            assert type(user_query) == str, "type(user_query) == str"
+            retrieved_tables = self.parse_table_names_from_sql(sql_result)
+            return self.matching_tables(user_query, retrieved_tables, metric=metric)
+
+        return func
+
+    def table_match(
+        self, user_query: str, sql_result: str, metric: str = "hamming_loss"
+    ):
+        retrieved_tables = self.parse_table_names_from_sql(sql_result)
+        return self.matching_tables(user_query, retrieved_tables, metric=metric)
+
+    @staticmethod
+    def parse_table_names_from_sql(sql: str) -> Set[str]:
+        """
+        Extracts table names from the given SQL query and returns a set of unique table names.
+
+        Args:
+            sql (str): The SQL query from which to extract table names.
+
+        Returns:
+            Set[str]: A set containing unique table names extracted from the SQL query.
+        """
+        # TODO: unit tests -- some SQL dialects might behave differently
+        parsed = sqlparse.parse(sql)
+        table_names = set()
+        for stmt in parsed:
+            for token in stmt.tokens:
+                # Check if the token is an Identifier and likely a table name
+                if (
+                    isinstance(token, sqlparse.sql.Identifier)
+                    and token.get_real_name().isidentifier()
+                ):
+                    table_names.add(token.get_real_name())
+
+        return table_names
+
+    def _qs_relevance(self, question: str, statement: str) -> str:
+        # Borrowed from super().qs_relevance()
+        return self.endpoint.run_me(
+            lambda: self._create_chat_completion(
+                prompt=str.format(
+                    prompts.QS_RELEVANCE, question=question, statement=statement
+                )
+            )
+        )
+
+    def query_sql_relevance(self, question: str, statement: dict) -> float:
+        """Judge the relevance of an example question-SQL pair to the user's input query."""
+        json_sql_question = statement
+        statement_str: str = (
+            f"Another user asked a question: '{json_sql_question['question']}', and the SQL statement used to answer it was: '{json_sql_question['sql']}'"
+        )
+
+        llm_response = self._qs_relevance(question, statement_str)
+        # Using rich for nicer printing
+        rich.print(
+            f"Running qs_relevance on inputs;"
+            f":: question = [bold]{question}[/bold]"
+            f":: statement = [bold]{statement_str}[/bold]\n"
+            f"LLM judgement: '[bold]{llm_response}[/bold]'\n"
+            f"score: [bold]{re_0_10_rating(llm_response)}[/bold]\n"
+        )
+
+        return re_0_10_rating(llm_response) / 10
+
+    def qs_relevance(self, question: str, statement: str) -> float:
+        """Judge the relevance of a retrieved context statement (documentation or DDL) to the user's question."""
+        llm_response = self._qs_relevance(question, statement)
+        # Using rich for nicer printing
+        rich.print(
+            f"Running qs_relevance on inputs;"
+            f":: question = [bold]{question}[/bold]"
+            f":: statement = [bold]{statement}[/bold]\n"
+            f"LLM judgement: '[bold]{llm_response}[/bold]'\n"
+            f"score: [bold]{re_0_10_rating(llm_response)}[/bold]\n"
+        )
+
+        return re_0_10_rating(llm_response) / 10
+
+
+def _load_metrics(prompts: List[dict], config: dict) -> List[Feedback]:
+    """Creates evaluation metrics for the TruLens recorder."""
+
+    # An evaluation model customised for SQL table-matching evaluation
+    provider = StandAloneProvider(
+        ground_truth_prompts=prompts,
+        model_engine=config["evaluation"]["model"],
+        possible_table_names=config["database"]["table_names"],
+        dbpath=config["database"]["path"],
+    )
+
+    # How well the response agrees with the known-to-be-true (ground truth) response.
+    ground_truth_collection = GroundTruthAgreement(
+        ground_truth=prompts, provider=provider
+    )
+    f_groundtruth_agreement_measure = Feedback(
+        ground_truth_collection.agreement_measure, name="Agreement-to-truth measure"
+    ).on_input_output()
+    # Note: the above could use a newly synthesized response from a superior model in place of the ground truth.
+    # This is different from query-statement relevance: it takes the app's response to a query
+    # and compares that response to a newly generated response to the same question, so no ground
+    # truth data is used -- it only checks whether the app performs similarly to another independent
+    # LLM.
Useful to give confidence that the app is as-performant as an independent "SOTA model" + + # For evaluating each retrieved context + f_qs_relevance_documentation = ( + Feedback(provider.qs_relevance, name="Query-Documentation Relevance") + .on( + question=Select.RecordCalls.get_related_documentation.args.question, + ) + .on(statement=Select.RecordCalls.get_related_documentation.rets[:]) + .aggregate(np.mean) + ) + f_qs_relevance_ddl = ( + Feedback(provider.qs_relevance, name="Query-DDL Relevance") + .on( + question=Select.RecordCalls.get_related_ddl.args.question, + ) + .on(statement=Select.RecordCalls.get_related_ddl.rets[:]) + .aggregate(np.mean) + ) + f_qs_relevance_sql = ( + Feedback(provider.query_sql_relevance, name="Query-SQL Relevance") + .on( + question=Select.RecordCalls.get_similar_question_sql.args.question, + ) + .on(statement=Select.RecordCalls.get_similar_question_sql.rets[:]) + .aggregate(np.mean) + ) + context_relevance_metrics = [ + f_qs_relevance_documentation, + f_qs_relevance_ddl, + f_qs_relevance_sql, + ] + + # For checking if retrieved context is relevant to the response (sql queries, table schemas, DDL or documentation). + # it looks for information overlap between retrieved documents, and the llms response. + grounded = Groundedness(provider) + f_groundedness_sql = ( + Feedback(grounded.groundedness_measure, name="groundedness_sql") + .on(statement=Select.RecordCalls.get_similar_question_sql.rets[:]) + .on_output() + .aggregate(grounded.grounded_statements_aggregator) + ) + f_groundedness_ddl = ( + Feedback(grounded.groundedness_measure, name="groundedness_ddl") + .on(Select.RecordCalls.get_related_ddl.rets[:]) + .on_output() + .aggregate(grounded.grounded_statements_aggregator) + ) + f_groundedness_document = ( + Feedback(grounded.groundedness_measure, name="groundedness_document") + .on(Select.RecordCalls.get_related_documentation.rets[:]) + .on_output() + .aggregate(grounded.grounded_statements_aggregator) + ) + + retrieval_metrics = [] + for metric in config["evaluation"]["retrieval"]["metrics"]: + F = provider.table_match_factory(metric=metric) + retrieval_metrics.append( + Feedback( + F, name=f"Matching Tables: {metric.replace('_', ' ').title()}" + ).on_input_output() + ) + + # Note: some metrics use the LLM to perform the scoring. Consumes tokens / cost. 
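+    # The agreement, relevance and groundedness feedbacks above call the evaluation LLM;
+    # the table-matching feedbacks are computed locally with scikit-learn.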
+    return [
+        f_groundtruth_agreement_measure,
+        f_groundedness_ddl,
+        f_groundedness_sql,
+        f_groundedness_document,
+        *context_relevance_metrics,
+        *retrieval_metrics,
+    ]
+
+
+def init_vanna_training(vn: ChromaDB_VectorStore, config):
+    """Adds context training examples to the vector store."""
+
+    assert all(
+        [
+            vn.remove_collection(collection_name=name)
+            for name in ["sql", "ddl", "documentation"]
+        ]
+    )
+    vn.connect_to_sqlite(config["database"]["path"])
+
+    vn.train(
+        sql="""SELECT FirstName from employees where FirstName LIKE 'A%';""",
+        question="List the employees whose first name starts with 'A'",
+    )
+    vn.train(
+        documentation="The 'employees' table contains the FirstName and LastName of employees"
+    )
+    vn.train(
+        ddl="""(
+        [EmployeeId] INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
+        [LastName] NVARCHAR(20) NOT NULL,
+        [FirstName] NVARCHAR(20) NOT NULL,
+        [Title] NVARCHAR(30),
+        [ReportsTo] INTEGER,
+        [BirthDate] DATETIME,
+        [HireDate] DATETIME,
+        FOREIGN KEY ([ReportsTo]) REFERENCES "employees" ([EmployeeId])
+        ON DELETE NO ACTION ON UPDATE NO ACTION)"""
+    )
+
+
+def run(vn: OllamaLocalDB, prompts: List[dict], config: dict, app_id=None):
+    """Creates metrics to evaluate the Vanna pipeline, instruments the app with them, then runs the test prompts."""
+
+    evaluation_metrics = _load_metrics(prompts, config)
+    tru_recorder = TruCustomApp(
+        vn,
+        feedbacks=evaluation_metrics,
+        tru=tru,
+        app_id=app_id if app_id else f"{config['model']}",
+    )
+
+    for i, prompt in enumerate(prompts):
+        response, record = tru_recorder.with_record(vn.generate_sql, prompt["query"])
+        # Manually add costs, since locally served Ollama & LiteLLM don't integrate cost & token counting
+        record.cost.n_prompt_tokens = vn.call_log[i][1]["prompt_eval_count"]
+        record.cost.n_completion_tokens = vn.call_log[i][1]["eval_count"]
+        record.cost.n_tokens = (
+            record.cost.n_completion_tokens + record.cost.n_prompt_tokens
+        )
+        record.cost.cost = 0.99  # Example cost of inference
+
+        tru_recorder.tru.add_record(record)
+
+
+if __name__ == "__main__":
+
+    # Configures both the Vanna app and the evaluation pipeline
+    config = {
+        # Vector store location
+        "path": f"./_vectorstore/{date.today()}",
+        "ollama_host": "http://localhost:11434",
+        # Ollama must be initialised with models. Specify which one to use here.
+        "model": "mistralreranker",
+        "database": {
+            "path": "/mnt/c/Users/ssch7/repos/db-chat-assistant/data/chinook.db",
+            # List all table names here, so the SQL in responses can be checked against them.
+            "table_names": ["employees", "artists", "customers"],
+        },
+        "evaluation": {
+            "retrieval": {"metrics": ["hamming_loss"]},
+            "model": "ollama/mistralreranker",
+        },
+    }
+
+    # Evaluate the app with these prompts and their known ground truth answers.
+    test_prompts = [
+        dict(
+            query="List the top twelve employees",
+            sql="SELECT FirstName, LastName from employees LIMIT 12;",
+            tables=["employees"],
+            prompt="List the top twelve employees",
+        ),
+        dict(
+            query="Who were the employees by first name, who served customers with first names starting with 'A'?",
+            sql="""SELECT e.FirstName FROM employees e JOIN customers c ON e.EmployeeId = c.SupportRepId WHERE c.FirstName LIKE 'A%';""",
+            tables=["employees", "customers"],
+            prompt="Who were the employees by first name, who served customers with first names starting with 'A'?",
+        ),
+    ]
+
+    # Note: Alternatively, download more test data from:
+    # wget https://github.com/jkkummerfeld/text2sql-data/blob/master/data/restaurants.json
+    # wget https://github.com/jkkummerfeld/text2sql-data/blob/master/data/restaurants-schema.csv
+
+    vn = OllamaLocalDB(config=config)
+    init_vanna_training(vn, config)
+    run(vn, test_prompts, config, app_id="Mistral 7B : OllamaLocalDB")
+
+    # Add a challenger app that uses a different LLM
+    config["model"] = "llama2reranker"
+    vn = OllamaLocalDB(config=config)
+    init_vanna_training(vn, config)
+    run(vn, test_prompts, config, app_id="Llama2 7B : OllamaLocalDB")
diff --git a/examples/evaluations.md b/examples/evaluations.md
new file mode 100644
index 00000000..21bb81b9
--- /dev/null
+++ b/examples/evaluations.md
@@ -0,0 +1,25 @@
+## Evaluation of Vanna using TruLens Eval
+
+1. [Install the TruLens package](https://www.trulens.org/trulens_eval/install/).
+2. In `examples/app_evaluation.py`, configure the metrics to use and the test data (example user questions) -- see the sketch at the end of this page.
+3. Run the script, and open the dashboard at http://localhost:8501 to view results as they're processed.
+
+
+### Compare multiple versions of the Vanna app
+![alt text](images/trulens-2.png)
+
+### See performance per-run of `vanna.generate_sql()`
+![alt text](images/trulens-1.png)
+
+### Overview of test results across multiple metrics
+![alt text](images/trulens-3.png)
+
+### Examining the inputs & outputs of context relevance metrics
+
+- **Agreement to truth**: the response SQL is compared to the ground truth SQL.
+- **Groundedness**: how much semantic overlap there is between a context item and the response SQL.
+  - **sql**: the final SQL and a question-SQL pair retrieved from the vector store. The score is an average over all retrieved SQL-question pairs (here only n=1 is retrieved).
+  - **ddl**: the final SQL and a DDL statement retrieved from the vector store.
+  - **document**: the final SQL and a document retrieved from the vector store.
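+
+For reference, a minimal sketch of the shape of a ground-truth test prompt consumed by `examples/app_evaluation.py` (the field names follow the example script; the question, SQL and table names below are illustrative only):
+
+```python
+# Illustrative entry -- adapt the question, SQL and tables to your own database.
+test_prompts = [
+    dict(
+        query="How many customers are there?",   # question sent to the app
+        sql="SELECT COUNT(*) FROM customers;",   # ground-truth SQL for this question
+        tables=["customers"],                    # tables the SQL should touch; used by the table-match metrics
+        prompt="How many customers are there?",  # duplicate of the question, kept alongside `query` in the script
+    ),
+]
+```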
+ +![alt text](images/trulens-4.png) diff --git a/examples/images/trulens-1.png b/examples/images/trulens-1.png new file mode 100644 index 00000000..48bd0eb2 Binary files /dev/null and b/examples/images/trulens-1.png differ diff --git a/examples/images/trulens-2.png b/examples/images/trulens-2.png new file mode 100644 index 00000000..a04d29e6 Binary files /dev/null and b/examples/images/trulens-2.png differ diff --git a/examples/images/trulens-3.png b/examples/images/trulens-3.png new file mode 100644 index 00000000..96e273f9 Binary files /dev/null and b/examples/images/trulens-3.png differ diff --git a/examples/images/trulens-4.png b/examples/images/trulens-4.png new file mode 100644 index 00000000..76ee12bf Binary files /dev/null and b/examples/images/trulens-4.png differ diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index ee1754c0..d2a6091a 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -37,7 +37,9 @@ def generate_sql(self, question: str, **kwargs) -> str: doc_list=doc_list, **kwargs, ) + self.log(prompt) llm_response = self.submit_prompt(prompt, **kwargs) + self.log(llm_response) return self.extract_sql(llm_response) def extract_sql(self, llm_response: str) -> str: @@ -133,6 +135,68 @@ def remove_training_data(id: str, **kwargs) -> bool: # ----------------- Use Any Language Model API ----------------- # @abstractmethod + def system_message(self, message: str) -> any: + pass + + @abstractmethod + def user_message(self, message: str) -> any: + pass + + @abstractmethod + def assistant_message(self, message: str) -> any: + pass + + def str_to_approx_token_count(self, string: str) -> int: + return len(string) / 4 + + def add_ddl_to_prompt( + self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000 + ) -> str: + if len(ddl_list) > 0: + initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for ddl in ddl_list: + if ( + self.str_to_approx_token_count(initial_prompt) + + self.str_to_approx_token_count(ddl) + < max_tokens + ): + initial_prompt += f"{ddl}\n\n" + + return initial_prompt + + def add_documentation_to_prompt( + self, initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000 + ) -> str: + if len(documentation_list) > 0: + initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for documentation in documentation_list: + if ( + self.str_to_approx_token_count(initial_prompt) + + self.str_to_approx_token_count(documentation) + < max_tokens + ): + initial_prompt += f"{documentation}\n\n" + + return initial_prompt + + def add_sql_to_prompt( + self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000 + ) -> str: + if len(sql_list) > 0: + initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for question in sql_list: + if ( + self.str_to_approx_token_count(initial_prompt) + + self.str_to_approx_token_count(question["sql"]) + < max_tokens + ): + initial_prompt += f"{question['question']}\n{question['sql']}\n\n" + + return initial_prompt + def get_sql_prompt( self, question: str, @@ -141,9 +205,30 @@ def get_sql_prompt( doc_list: list, **kwargs, ): - pass + initial_prompt = "The user provides a question and you provide SQL. 
You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" + + initial_prompt = self.add_ddl_to_prompt( + initial_prompt, ddl_list, max_tokens=14000 + ) + + initial_prompt = self.add_documentation_to_prompt( + initial_prompt, doc_list, max_tokens=14000 + ) + + message_log = [self.system_message(initial_prompt)] + + for example in question_sql_list: + if example is None: + print("example is None") + else: + if example is not None and "question" in example and "sql" in example: + message_log.append(self.user_message(example["question"])) + message_log.append(self.assistant_message(example["sql"])) + + message_log.append(self.user_message(question)) + + return message_log - @abstractmethod def get_followup_questions_prompt( self, question: str, @@ -151,22 +236,94 @@ def get_followup_questions_prompt( ddl_list: list, doc_list: list, **kwargs, - ): - pass + ) -> list: + initial_prompt = f"The user initially asked the question: '{question}': \n\n" + + initial_prompt = self.add_ddl_to_prompt( + initial_prompt, ddl_list, max_tokens=14000 + ) + + initial_prompt = self.add_documentation_to_prompt( + initial_prompt, doc_list, max_tokens=14000 + ) + + initial_prompt = self.add_sql_to_prompt( + initial_prompt, question_sql_list, max_tokens=14000 + ) + + message_log = [self.system_message(initial_prompt)] + message_log.append( + self.user_message( + "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions." + ) + ) + + return message_log @abstractmethod def submit_prompt(self, prompt, **kwargs) -> str: pass - @abstractmethod def generate_question(self, sql: str, **kwargs) -> str: - pass + response = self.submit_prompt( + [ + self.system_message( + "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." 
+ ), + self.user_message(sql), + ], + **kwargs, + ) + + return response + + def _extract_python_code(self, markdown_string: str) -> str: + # Regex pattern to match Python code blocks + pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" + + # Find all matches in the markdown string + matches = re.findall(pattern, markdown_string, re.IGNORECASE) + + # Extract the Python code from the matches + python_code = [] + for match in matches: + python = match[0] if match[0] else match[1] + python_code.append(python.strip()) + + if len(python_code) == 0: + return markdown_string + + return python_code[0] + + def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: + # Remove the fig.show() statement from the plotly code + plotly_code = raw_plotly_code.replace("fig.show()", "") + + return plotly_code - @abstractmethod def generate_plotly_code( self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs ) -> str: - pass + if question is not None: + system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'" + else: + system_msg = "The following is a pandas DataFrame " + + if sql is not None: + system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n" + + system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}" + + message_log = [ + self.system_message(system_msg), + self.user_message( + "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code." + ), + ] + + plotly_code = self.submit_prompt(message_log, kwargs=kwargs) + + return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) # ----------------- Connect to Any Database to run the Generated SQL ----------------- # @@ -469,6 +626,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_bigquery + def connect_to_duckdb(self, url: str, init_sql: str = None): """ Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] @@ -514,9 +672,10 @@ def run_sql_duckdb(sql: str): self.run_sql = run_sql_duckdb self.run_sql_is_set = True + def run_sql(sql: str, **kwargs) -> pd.DataFrame: raise NotImplementedError( - "You need to connect_to_snowflake or other database first." + "You need to connect to a database first by running vn.connect_to_snowflake(), vn.connect_to_sqlite(), vn.connect_to_postgres(), vn.connect_to_bigquery(), or vn.connect_to_duckdb()" ) def ask( @@ -533,6 +692,24 @@ def ask( ], None, ]: + """ + **Example:** + ```python + vn.ask("What are the top 10 customers by sales?") + ``` + + Ask Vanna.AI a question and get the SQL query that answers it. + + Args: + question (str): The question to ask. + print_results (bool): Whether to print the results of the SQL query. + auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query. + visualize (bool): Whether to generate plotly code and display the plotly figure. + + Returns: + Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]: The SQL query, the results of the SQL query, and the plotly figure. 
+ """ + if question is None: question = input("Enter a question: ") @@ -927,60 +1104,3 @@ def get_plotly_figure( return fig - -class SplitStorage(VannaBase): - def __init__(self, config=None): - VannaBase.__init__(self, config=config) - - def get_similar_question_sql(self, embedding: str, **kwargs) -> list: - question_sql_ids = self.get_similar_question_sql_ids(embedding, **kwargs) - question_sql_list = self.get_question_sql(question_sql_ids, **kwargs) - return question_sql_list - - def get_related_ddl(self, embedding: str, **kwargs) -> list: - ddl_ids = self.get_related_ddl_ids(embedding, **kwargs) - ddl_list = self.get_ddl(ddl_ids, **kwargs) - return ddl_list - - def get_related_documentation(self, embedding: str, **kwargs) -> list: - doc_ids = self.get_related_documentation_ids(embedding, **kwargs) - doc_list = self.get_documentation(doc_ids, **kwargs) - return doc_list - - # ----------------- Use Any Vector Database to Store and Lookup Embedding Similarity ----------------- # - @abstractmethod - def store_question_sql_embedding(self, embedding: str, **kwargs) -> str: - pass - - @abstractmethod - def store_ddl_embedding(self, embedding: str, **kwargs) -> str: - pass - - @abstractmethod - def store_documentation_embedding(self, embedding: str, **kwargs) -> str: - pass - - @abstractmethod - def get_similar_question_sql_ids(self, embedding: str, **kwargs) -> list: - pass - - @abstractmethod - def get_related_ddl_ids(self, embedding: str, **kwargs) -> list: - pass - - @abstractmethod - def get_related_documentation_ids(self, embedding: str, **kwargs) -> list: - pass - - # ----------------- Use Database to Retrieve the Documents from ID Lists ----------------- # - @abstractmethod - def get_question_sql(self, question_sql_ids: list, **kwargs) -> list: - pass - - @abstractmethod - def get_documentation(self, doc_ids: list, **kwargs) -> list: - pass - - @abstractmethod - def get_ddl(self, ddl_ids: list, **kwargs) -> list: - pass diff --git a/src/vanna/flask.py b/src/vanna/flask.py index 57bec28f..5177ceaf 100644 --- a/src/vanna/flask.py +++ b/src/vanna/flask.py @@ -251,7 +251,7 @@ def get_training_data(): { "type": "df", "id": "training_data", - "df": df.tail(25).to_json(orient="records"), + "df": df.to_json(orient="records"), } ) diff --git a/src/vanna/mistral/mistral.py b/src/vanna/mistral/mistral.py index 8acccfcd..627c46e7 100644 --- a/src/vanna/mistral/mistral.py +++ b/src/vanna/mistral/mistral.py @@ -19,149 +19,15 @@ def __init__(self, config=None): self.client = MistralClient(api_key=api_key) self.model = model - def _extract_python_code(self, markdown_string: str) -> str: - # Regex pattern to match Python code blocks - pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" - - # Find all matches in the markdown string - matches = re.findall(pattern, markdown_string, re.IGNORECASE) - - # Extract the Python code from the matches - python_code = [] - for match in matches: - python = match[0] if match[0] else match[1] - python_code.append(python.strip()) - - if len(python_code) == 0: - return markdown_string - - return python_code[0] - - def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: - # Remove the fig.show() statement from the plotly code - plotly_code = raw_plotly_code.replace("fig.show()", "") - - return plotly_code - - def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str: - if question is not None: - system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the 
question the user asked: '{question}'" - else: - system_msg = "The following is a pandas DataFrame " - - if sql is not None: - system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n" - - system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}" - - message_log = [ - self.system_message(system_msg), - self.user_message( - "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code." - ), - ] - - plotly_code = self.submit_prompt(message_log, kwargs=kwargs) - - return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) - - def generate_question(self, sql: str, **kwargs) -> str: - response = self.submit_prompt( - [ - self.system_message( - "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." - ), - self.user_message(sql), - ], - **kwargs, - ) - - return response - - def get_followup_questions_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs): - initial_prompt = f"The user initially asked the question: '{question}': \n\n" - - initial_prompt = Mistral.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) - - initial_prompt = Mistral.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) - - initial_prompt = Mistral.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000) - - message_log = [Mistral.system_message(initial_prompt)] - message_log.append(Mistral.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions.")) - - return message_log - - @staticmethod - def system_message(message: str) -> dict: + def system_message(self, message: str) -> any: return ChatMessage(role="system", content=message) - @staticmethod - def user_message(message: str) -> dict: + def user_message(self, message: str) -> any: return ChatMessage(role="user", content=message) - @staticmethod - def assistant_message(message: str) -> dict: + def assistant_message(self, message: str) -> any: return ChatMessage(role="assistant", content=message) - @staticmethod - def str_to_approx_token_count(string: str) -> int: - return len(string) / 4 - - @staticmethod - def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str: - if len(ddl_list) > 0: - initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for ddl in ddl_list: - if Mistral.str_to_approx_token_count(initial_prompt) + Mistral.str_to_approx_token_count(ddl) < max_tokens: - initial_prompt += f"{ddl}\n\n" - - return initial_prompt - - @staticmethod - def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str: - if len(documentation_list) > 0: - initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. 
Use responses to past questions also to guide you:\n\n" - - for documentation in documentation_list: - if Mistral.str_to_approx_token_count(initial_prompt) + Mistral.str_to_approx_token_count(documentation) < max_tokens: - initial_prompt += f"{documentation}\n\n" - - return initial_prompt - - @staticmethod - def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str: - if len(sql_list) > 0: - initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for question in sql_list: - if Mistral.str_to_approx_token_count(initial_prompt) + Mistral.str_to_approx_token_count(question["sql"]) < max_tokens: - initial_prompt += f"{question['question']}\n{question['sql']}\n\n" - - return initial_prompt - - def get_sql_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs): - initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" - - initial_prompt = Mistral.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) - - initial_prompt = Mistral.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) - - message_log = [Mistral.system_message(initial_prompt)] - - for example in question_sql_list: - if example is None: - print("example is None") - else: - if example is not None and "question" in example and "sql" in example: - message_log.append(Mistral.user_message(example["question"])) - message_log.append(Mistral.assistant_message(example["sql"])) - - message_log.append(ChatMessage(role="user", content=question)) - - return message_log - def generate_sql(self, question: str, **kwargs) -> str: # Use the super generate_sql sql = super().generate_sql(question, **kwargs) diff --git a/src/vanna/ollama/__init__.py b/src/vanna/ollama/__init__.py new file mode 100644 index 00000000..0f4f48e2 --- /dev/null +++ b/src/vanna/ollama/__init__.py @@ -0,0 +1,50 @@ +from ..base import VannaBase +import requests +import json + +class Ollama(VannaBase): + def __init__(self, config=None): + if config is None or 'ollama_host' not in config: + self.host = "http://localhost:11434" + else: + self.host = config['ollama_host'] + + if config is None or 'model' not in config: + raise ValueError("config must contain a Ollama model") + else: + self.model = config['model'] + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def generate_sql(self, question: str, **kwargs) -> str: + # Use the super generate_sql + sql = super().generate_sql(question, **kwargs) + + # Replace "\_" with "_" + sql = sql.replace("\\_", "_") + + return sql + + def submit_prompt(self, prompt, **kwargs) -> str: + url = f"{self.host}/api/chat" + data = { + "model": self.model, + "stream": False, + "messages": prompt, + } + + response = requests.post(url, json=data) + + response_dict = response.json() + + self.log(response.text) + + return response_dict['message']['content'] + diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vanna/openai/openai_chat.py 
b/src/vanna/openai/openai_chat.py index 3125c36b..febb089d 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -38,198 +38,15 @@ def __init__(self, client=None, config=None): if "api_key" in config: self.client = OpenAI(api_key=config["api_key"]) - @staticmethod - def system_message(message: str) -> dict: + def system_message(self, message: str) -> any: return {"role": "system", "content": message} - @staticmethod - def user_message(message: str) -> dict: + def user_message(self, message: str) -> any: return {"role": "user", "content": message} - @staticmethod - def assistant_message(message: str) -> dict: + def assistant_message(self, message: str) -> any: return {"role": "assistant", "content": message} - @staticmethod - def str_to_approx_token_count(string: str) -> int: - return len(string) / 4 - - @staticmethod - def add_ddl_to_prompt( - initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000 - ) -> str: - if len(ddl_list) > 0: - initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for ddl in ddl_list: - if ( - OpenAI_Chat.str_to_approx_token_count(initial_prompt) - + OpenAI_Chat.str_to_approx_token_count(ddl) - < max_tokens - ): - initial_prompt += f"{ddl}\n\n" - - return initial_prompt - - @staticmethod - def add_documentation_to_prompt( - initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000 - ) -> str: - if len(documentation_list) > 0: - initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for documentation in documentation_list: - if ( - OpenAI_Chat.str_to_approx_token_count(initial_prompt) - + OpenAI_Chat.str_to_approx_token_count(documentation) - < max_tokens - ): - initial_prompt += f"{documentation}\n\n" - - return initial_prompt - - @staticmethod - def add_sql_to_prompt( - initial_prompt: str, sql_list: list[str], max_tokens: int = 14000 - ) -> str: - if len(sql_list) > 0: - initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for question in sql_list: - if ( - OpenAI_Chat.str_to_approx_token_count(initial_prompt) - + OpenAI_Chat.str_to_approx_token_count(question["sql"]) - < max_tokens - ): - initial_prompt += f"{question['question']}\n{question['sql']}\n\n" - - return initial_prompt - - def get_sql_prompt( - self, - question: str, - question_sql_list: list, - ddl_list: list, - doc_list: list, - **kwargs, - ): - initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. 
Do not answer with any explanations -- just the code.\n" - - initial_prompt = OpenAI_Chat.add_ddl_to_prompt( - initial_prompt, ddl_list, max_tokens=14000 - ) - - initial_prompt = OpenAI_Chat.add_documentation_to_prompt( - initial_prompt, doc_list, max_tokens=14000 - ) - - message_log = [OpenAI_Chat.system_message(initial_prompt)] - - for example in question_sql_list: - if example is None: - print("example is None") - else: - if example is not None and "question" in example and "sql" in example: - message_log.append(OpenAI_Chat.user_message(example["question"])) - message_log.append(OpenAI_Chat.assistant_message(example["sql"])) - - message_log.append({"role": "user", "content": question}) - - return message_log - - def get_followup_questions_prompt( - self, - question: str, - df: pd.DataFrame, - question_sql_list: list, - ddl_list: list, - doc_list: list, - **kwargs, - ): - initial_prompt = f"The user initially asked the question: '{question}': \n\n" - - initial_prompt = OpenAI_Chat.add_ddl_to_prompt( - initial_prompt, ddl_list, max_tokens=14000 - ) - - initial_prompt = OpenAI_Chat.add_documentation_to_prompt( - initial_prompt, doc_list, max_tokens=14000 - ) - - initial_prompt = OpenAI_Chat.add_sql_to_prompt( - initial_prompt, question_sql_list, max_tokens=14000 - ) - - message_log = [OpenAI_Chat.system_message(initial_prompt)] - message_log.append( - OpenAI_Chat.user_message( - "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions." - ) - ) - - return message_log - - def generate_question(self, sql: str, **kwargs) -> str: - response = self.submit_prompt( - [ - self.system_message( - "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." - ), - self.user_message(sql), - ], - **kwargs, - ) - - return response - - def _extract_python_code(self, markdown_string: str) -> str: - # Regex pattern to match Python code blocks - pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" - - # Find all matches in the markdown string - matches = re.findall(pattern, markdown_string, re.IGNORECASE) - - # Extract the Python code from the matches - python_code = [] - for match in matches: - python = match[0] if match[0] else match[1] - python_code.append(python.strip()) - - if len(python_code) == 0: - return markdown_string - - return python_code[0] - - def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: - # Remove the fig.show() statement from the plotly code - plotly_code = raw_plotly_code.replace("fig.show()", "") - - return plotly_code - - def generate_plotly_code( - self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs - ) -> str: - if question is not None: - system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'" - else: - system_msg = "The following is a pandas DataFrame " - - if sql is not None: - system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n" - - system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}" - - message_log = [ - self.system_message(system_msg), - self.user_message( - "Can you generate the Python plotly code to chart the results of the dataframe? 
Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code." - ), - ] - - plotly_code = self.submit_prompt(message_log, kwargs=kwargs) - - return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) - def submit_prompt(self, prompt, **kwargs) -> str: if prompt is None: raise Exception("Prompt is None")
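Taken together, the `VannaBase` changes above move prompt construction (`get_sql_prompt`, `get_followup_questions_prompt`, the `add_*_to_prompt` helpers) and the plotting/question helpers into the base class, so an LLM backend only needs to supply the message formatters and `submit_prompt`. A rough sketch of what such a backend can now look like (the `EchoLLM` class and its canned response are illustrative, not part of this diff; a real subclass would also be combined with a vector store mixin such as `ChromaDB_VectorStore`, as the example app does):

```python
from vanna.base import VannaBase


class EchoLLM(VannaBase):
    """Illustrative stub: the minimal LLM surface after this refactor."""

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        # A real backend would send `prompt` (the message log built by the base class)
        # to an LLM API; here we just return a canned answer.
        return "SELECT 1;"
```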