From b4aea353b27e0293b669b5e19a75f17709c094d3 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Fri, 26 Jan 2024 23:05:36 -0500 Subject: [PATCH 1/8] ollama --- src/vanna/base/base.py | 252 +++++++++++++++++++++++--------- src/vanna/mistral/mistral.py | 140 +----------------- src/vanna/ollama/__init__.py | 50 +++++++ src/vanna/ollama/ollama.py | 0 src/vanna/openai/openai_chat.py | 189 +----------------------- 5 files changed, 242 insertions(+), 389 deletions(-) create mode 100644 src/vanna/ollama/__init__.py create mode 100644 src/vanna/ollama/ollama.py diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index ee1754c0..d2a6091a 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -37,7 +37,9 @@ def generate_sql(self, question: str, **kwargs) -> str: doc_list=doc_list, **kwargs, ) + self.log(prompt) llm_response = self.submit_prompt(prompt, **kwargs) + self.log(llm_response) return self.extract_sql(llm_response) def extract_sql(self, llm_response: str) -> str: @@ -133,6 +135,68 @@ def remove_training_data(id: str, **kwargs) -> bool: # ----------------- Use Any Language Model API ----------------- # @abstractmethod + def system_message(self, message: str) -> any: + pass + + @abstractmethod + def user_message(self, message: str) -> any: + pass + + @abstractmethod + def assistant_message(self, message: str) -> any: + pass + + def str_to_approx_token_count(self, string: str) -> int: + return len(string) / 4 + + def add_ddl_to_prompt( + self, initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000 + ) -> str: + if len(ddl_list) > 0: + initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for ddl in ddl_list: + if ( + self.str_to_approx_token_count(initial_prompt) + + self.str_to_approx_token_count(ddl) + < max_tokens + ): + initial_prompt += f"{ddl}\n\n" + + return initial_prompt + + def add_documentation_to_prompt( + self, initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000 + ) -> str: + if len(documentation_list) > 0: + initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for documentation in documentation_list: + if ( + self.str_to_approx_token_count(initial_prompt) + + self.str_to_approx_token_count(documentation) + < max_tokens + ): + initial_prompt += f"{documentation}\n\n" + + return initial_prompt + + def add_sql_to_prompt( + self, initial_prompt: str, sql_list: list[str], max_tokens: int = 14000 + ) -> str: + if len(sql_list) > 0: + initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" + + for question in sql_list: + if ( + self.str_to_approx_token_count(initial_prompt) + + self.str_to_approx_token_count(question["sql"]) + < max_tokens + ): + initial_prompt += f"{question['question']}\n{question['sql']}\n\n" + + return initial_prompt + def get_sql_prompt( self, question: str, @@ -141,9 +205,30 @@ def get_sql_prompt( doc_list: list, **kwargs, ): - pass + initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" + + initial_prompt = self.add_ddl_to_prompt( + initial_prompt, ddl_list, max_tokens=14000 + ) + + initial_prompt = self.add_documentation_to_prompt( + initial_prompt, doc_list, max_tokens=14000 + ) + + message_log = [self.system_message(initial_prompt)] + + for example in question_sql_list: + if example is None: + print("example is None") + else: + if example is not None and "question" in example and "sql" in example: + message_log.append(self.user_message(example["question"])) + message_log.append(self.assistant_message(example["sql"])) + + message_log.append(self.user_message(question)) + + return message_log - @abstractmethod def get_followup_questions_prompt( self, question: str, @@ -151,22 +236,94 @@ def get_followup_questions_prompt( ddl_list: list, doc_list: list, **kwargs, - ): - pass + ) -> list: + initial_prompt = f"The user initially asked the question: '{question}': \n\n" + + initial_prompt = self.add_ddl_to_prompt( + initial_prompt, ddl_list, max_tokens=14000 + ) + + initial_prompt = self.add_documentation_to_prompt( + initial_prompt, doc_list, max_tokens=14000 + ) + + initial_prompt = self.add_sql_to_prompt( + initial_prompt, question_sql_list, max_tokens=14000 + ) + + message_log = [self.system_message(initial_prompt)] + message_log.append( + self.user_message( + "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions." + ) + ) + + return message_log @abstractmethod def submit_prompt(self, prompt, **kwargs) -> str: pass - @abstractmethod def generate_question(self, sql: str, **kwargs) -> str: - pass + response = self.submit_prompt( + [ + self.system_message( + "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." + ), + self.user_message(sql), + ], + **kwargs, + ) + + return response + + def _extract_python_code(self, markdown_string: str) -> str: + # Regex pattern to match Python code blocks + pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" + + # Find all matches in the markdown string + matches = re.findall(pattern, markdown_string, re.IGNORECASE) + + # Extract the Python code from the matches + python_code = [] + for match in matches: + python = match[0] if match[0] else match[1] + python_code.append(python.strip()) + + if len(python_code) == 0: + return markdown_string + + return python_code[0] + + def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: + # Remove the fig.show() statement from the plotly code + plotly_code = raw_plotly_code.replace("fig.show()", "") + + return plotly_code - @abstractmethod def generate_plotly_code( self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs ) -> str: - pass + if question is not None: + system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'" + else: + system_msg = "The following is a pandas DataFrame " + + if sql is not None: + system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n" + + system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}" + + message_log = [ + self.system_message(system_msg), + self.user_message( + "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code." + ), + ] + + plotly_code = self.submit_prompt(message_log, kwargs=kwargs) + + return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) # ----------------- Connect to Any Database to run the Generated SQL ----------------- # @@ -469,6 +626,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: self.run_sql_is_set = True self.run_sql = run_sql_bigquery + def connect_to_duckdb(self, url: str, init_sql: str = None): """ Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] @@ -514,9 +672,10 @@ def run_sql_duckdb(sql: str): self.run_sql = run_sql_duckdb self.run_sql_is_set = True + def run_sql(sql: str, **kwargs) -> pd.DataFrame: raise NotImplementedError( - "You need to connect_to_snowflake or other database first." + "You need to connect to a database first by running vn.connect_to_snowflake(), vn.connect_to_sqlite(), vn.connect_to_postgres(), vn.connect_to_bigquery(), or vn.connect_to_duckdb()" ) def ask( @@ -533,6 +692,24 @@ def ask( ], None, ]: + """ + **Example:** + ```python + vn.ask("What are the top 10 customers by sales?") + ``` + + Ask Vanna.AI a question and get the SQL query that answers it. + + Args: + question (str): The question to ask. + print_results (bool): Whether to print the results of the SQL query. + auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query. + visualize (bool): Whether to generate plotly code and display the plotly figure. + + Returns: + Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]: The SQL query, the results of the SQL query, and the plotly figure. + """ + if question is None: question = input("Enter a question: ") @@ -927,60 +1104,3 @@ def get_plotly_figure( return fig - -class SplitStorage(VannaBase): - def __init__(self, config=None): - VannaBase.__init__(self, config=config) - - def get_similar_question_sql(self, embedding: str, **kwargs) -> list: - question_sql_ids = self.get_similar_question_sql_ids(embedding, **kwargs) - question_sql_list = self.get_question_sql(question_sql_ids, **kwargs) - return question_sql_list - - def get_related_ddl(self, embedding: str, **kwargs) -> list: - ddl_ids = self.get_related_ddl_ids(embedding, **kwargs) - ddl_list = self.get_ddl(ddl_ids, **kwargs) - return ddl_list - - def get_related_documentation(self, embedding: str, **kwargs) -> list: - doc_ids = self.get_related_documentation_ids(embedding, **kwargs) - doc_list = self.get_documentation(doc_ids, **kwargs) - return doc_list - - # ----------------- Use Any Vector Database to Store and Lookup Embedding Similarity ----------------- # - @abstractmethod - def store_question_sql_embedding(self, embedding: str, **kwargs) -> str: - pass - - @abstractmethod - def store_ddl_embedding(self, embedding: str, **kwargs) -> str: - pass - - @abstractmethod - def store_documentation_embedding(self, embedding: str, **kwargs) -> str: - pass - - @abstractmethod - def get_similar_question_sql_ids(self, embedding: str, **kwargs) -> list: - pass - - @abstractmethod - def get_related_ddl_ids(self, embedding: str, **kwargs) -> list: - pass - - @abstractmethod - def get_related_documentation_ids(self, embedding: str, **kwargs) -> list: - pass - - # ----------------- Use Database to Retrieve the Documents from ID Lists ----------------- # - @abstractmethod - def get_question_sql(self, question_sql_ids: list, **kwargs) -> list: - pass - - @abstractmethod - def get_documentation(self, doc_ids: list, **kwargs) -> list: - pass - - @abstractmethod - def get_ddl(self, ddl_ids: list, **kwargs) -> list: - pass diff --git a/src/vanna/mistral/mistral.py b/src/vanna/mistral/mistral.py index 8acccfcd..627c46e7 100644 --- a/src/vanna/mistral/mistral.py +++ b/src/vanna/mistral/mistral.py @@ -19,149 +19,15 @@ def __init__(self, config=None): self.client = MistralClient(api_key=api_key) self.model = model - def _extract_python_code(self, markdown_string: str) -> str: - # Regex pattern to match Python code blocks - pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" - - # Find all matches in the markdown string - matches = re.findall(pattern, markdown_string, re.IGNORECASE) - - # Extract the Python code from the matches - python_code = [] - for match in matches: - python = match[0] if match[0] else match[1] - python_code.append(python.strip()) - - if len(python_code) == 0: - return markdown_string - - return python_code[0] - - def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: - # Remove the fig.show() statement from the plotly code - plotly_code = raw_plotly_code.replace("fig.show()", "") - - return plotly_code - - def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str: - if question is not None: - system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'" - else: - system_msg = "The following is a pandas DataFrame " - - if sql is not None: - system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n" - - system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}" - - message_log = [ - self.system_message(system_msg), - self.user_message( - "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code." - ), - ] - - plotly_code = self.submit_prompt(message_log, kwargs=kwargs) - - return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) - - def generate_question(self, sql: str, **kwargs) -> str: - response = self.submit_prompt( - [ - self.system_message( - "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." - ), - self.user_message(sql), - ], - **kwargs, - ) - - return response - - def get_followup_questions_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs): - initial_prompt = f"The user initially asked the question: '{question}': \n\n" - - initial_prompt = Mistral.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) - - initial_prompt = Mistral.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) - - initial_prompt = Mistral.add_sql_to_prompt(initial_prompt, question_sql_list, max_tokens=14000) - - message_log = [Mistral.system_message(initial_prompt)] - message_log.append(Mistral.user_message("Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions.")) - - return message_log - - @staticmethod - def system_message(message: str) -> dict: + def system_message(self, message: str) -> any: return ChatMessage(role="system", content=message) - @staticmethod - def user_message(message: str) -> dict: + def user_message(self, message: str) -> any: return ChatMessage(role="user", content=message) - @staticmethod - def assistant_message(message: str) -> dict: + def assistant_message(self, message: str) -> any: return ChatMessage(role="assistant", content=message) - @staticmethod - def str_to_approx_token_count(string: str) -> int: - return len(string) / 4 - - @staticmethod - def add_ddl_to_prompt(initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000) -> str: - if len(ddl_list) > 0: - initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for ddl in ddl_list: - if Mistral.str_to_approx_token_count(initial_prompt) + Mistral.str_to_approx_token_count(ddl) < max_tokens: - initial_prompt += f"{ddl}\n\n" - - return initial_prompt - - @staticmethod - def add_documentation_to_prompt(initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000) -> str: - if len(documentation_list) > 0: - initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for documentation in documentation_list: - if Mistral.str_to_approx_token_count(initial_prompt) + Mistral.str_to_approx_token_count(documentation) < max_tokens: - initial_prompt += f"{documentation}\n\n" - - return initial_prompt - - @staticmethod - def add_sql_to_prompt(initial_prompt: str, sql_list: list[str], max_tokens: int = 14000) -> str: - if len(sql_list) > 0: - initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for question in sql_list: - if Mistral.str_to_approx_token_count(initial_prompt) + Mistral.str_to_approx_token_count(question["sql"]) < max_tokens: - initial_prompt += f"{question['question']}\n{question['sql']}\n\n" - - return initial_prompt - - def get_sql_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs): - initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" - - initial_prompt = Mistral.add_ddl_to_prompt(initial_prompt, ddl_list, max_tokens=14000) - - initial_prompt = Mistral.add_documentation_to_prompt(initial_prompt, doc_list, max_tokens=14000) - - message_log = [Mistral.system_message(initial_prompt)] - - for example in question_sql_list: - if example is None: - print("example is None") - else: - if example is not None and "question" in example and "sql" in example: - message_log.append(Mistral.user_message(example["question"])) - message_log.append(Mistral.assistant_message(example["sql"])) - - message_log.append(ChatMessage(role="user", content=question)) - - return message_log - def generate_sql(self, question: str, **kwargs) -> str: # Use the super generate_sql sql = super().generate_sql(question, **kwargs) diff --git a/src/vanna/ollama/__init__.py b/src/vanna/ollama/__init__.py new file mode 100644 index 00000000..0f4f48e2 --- /dev/null +++ b/src/vanna/ollama/__init__.py @@ -0,0 +1,50 @@ +from ..base import VannaBase +import requests +import json + +class Ollama(VannaBase): + def __init__(self, config=None): + if config is None or 'ollama_host' not in config: + self.host = "http://localhost:11434" + else: + self.host = config['ollama_host'] + + if config is None or 'model' not in config: + raise ValueError("config must contain a Ollama model") + else: + self.model = config['model'] + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def generate_sql(self, question: str, **kwargs) -> str: + # Use the super generate_sql + sql = super().generate_sql(question, **kwargs) + + # Replace "\_" with "_" + sql = sql.replace("\\_", "_") + + return sql + + def submit_prompt(self, prompt, **kwargs) -> str: + url = f"{self.host}/api/chat" + data = { + "model": self.model, + "stream": False, + "messages": prompt, + } + + response = requests.post(url, json=data) + + response_dict = response.json() + + self.log(response.text) + + return response_dict['message']['content'] + diff --git a/src/vanna/ollama/ollama.py b/src/vanna/ollama/ollama.py new file mode 100644 index 00000000..e69de29b diff --git a/src/vanna/openai/openai_chat.py b/src/vanna/openai/openai_chat.py index 3125c36b..febb089d 100644 --- a/src/vanna/openai/openai_chat.py +++ b/src/vanna/openai/openai_chat.py @@ -38,198 +38,15 @@ def __init__(self, client=None, config=None): if "api_key" in config: self.client = OpenAI(api_key=config["api_key"]) - @staticmethod - def system_message(message: str) -> dict: + def system_message(self, message: str) -> any: return {"role": "system", "content": message} - @staticmethod - def user_message(message: str) -> dict: + def user_message(self, message: str) -> any: return {"role": "user", "content": message} - @staticmethod - def assistant_message(message: str) -> dict: + def assistant_message(self, message: str) -> any: return {"role": "assistant", "content": message} - @staticmethod - def str_to_approx_token_count(string: str) -> int: - return len(string) / 4 - - @staticmethod - def add_ddl_to_prompt( - initial_prompt: str, ddl_list: list[str], max_tokens: int = 14000 - ) -> str: - if len(ddl_list) > 0: - initial_prompt += f"\nYou may use the following DDL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for ddl in ddl_list: - if ( - OpenAI_Chat.str_to_approx_token_count(initial_prompt) - + OpenAI_Chat.str_to_approx_token_count(ddl) - < max_tokens - ): - initial_prompt += f"{ddl}\n\n" - - return initial_prompt - - @staticmethod - def add_documentation_to_prompt( - initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000 - ) -> str: - if len(documentation_list) > 0: - initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for documentation in documentation_list: - if ( - OpenAI_Chat.str_to_approx_token_count(initial_prompt) - + OpenAI_Chat.str_to_approx_token_count(documentation) - < max_tokens - ): - initial_prompt += f"{documentation}\n\n" - - return initial_prompt - - @staticmethod - def add_sql_to_prompt( - initial_prompt: str, sql_list: list[str], max_tokens: int = 14000 - ) -> str: - if len(sql_list) > 0: - initial_prompt += f"\nYou may use the following SQL statements as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n" - - for question in sql_list: - if ( - OpenAI_Chat.str_to_approx_token_count(initial_prompt) - + OpenAI_Chat.str_to_approx_token_count(question["sql"]) - < max_tokens - ): - initial_prompt += f"{question['question']}\n{question['sql']}\n\n" - - return initial_prompt - - def get_sql_prompt( - self, - question: str, - question_sql_list: list, - ddl_list: list, - doc_list: list, - **kwargs, - ): - initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" - - initial_prompt = OpenAI_Chat.add_ddl_to_prompt( - initial_prompt, ddl_list, max_tokens=14000 - ) - - initial_prompt = OpenAI_Chat.add_documentation_to_prompt( - initial_prompt, doc_list, max_tokens=14000 - ) - - message_log = [OpenAI_Chat.system_message(initial_prompt)] - - for example in question_sql_list: - if example is None: - print("example is None") - else: - if example is not None and "question" in example and "sql" in example: - message_log.append(OpenAI_Chat.user_message(example["question"])) - message_log.append(OpenAI_Chat.assistant_message(example["sql"])) - - message_log.append({"role": "user", "content": question}) - - return message_log - - def get_followup_questions_prompt( - self, - question: str, - df: pd.DataFrame, - question_sql_list: list, - ddl_list: list, - doc_list: list, - **kwargs, - ): - initial_prompt = f"The user initially asked the question: '{question}': \n\n" - - initial_prompt = OpenAI_Chat.add_ddl_to_prompt( - initial_prompt, ddl_list, max_tokens=14000 - ) - - initial_prompt = OpenAI_Chat.add_documentation_to_prompt( - initial_prompt, doc_list, max_tokens=14000 - ) - - initial_prompt = OpenAI_Chat.add_sql_to_prompt( - initial_prompt, question_sql_list, max_tokens=14000 - ) - - message_log = [OpenAI_Chat.system_message(initial_prompt)] - message_log.append( - OpenAI_Chat.user_message( - "Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions." - ) - ) - - return message_log - - def generate_question(self, sql: str, **kwargs) -> str: - response = self.submit_prompt( - [ - self.system_message( - "The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question." - ), - self.user_message(sql), - ], - **kwargs, - ) - - return response - - def _extract_python_code(self, markdown_string: str) -> str: - # Regex pattern to match Python code blocks - pattern = r"```[\w\s]*python\n([\s\S]*?)```|```([\s\S]*?)```" - - # Find all matches in the markdown string - matches = re.findall(pattern, markdown_string, re.IGNORECASE) - - # Extract the Python code from the matches - python_code = [] - for match in matches: - python = match[0] if match[0] else match[1] - python_code.append(python.strip()) - - if len(python_code) == 0: - return markdown_string - - return python_code[0] - - def _sanitize_plotly_code(self, raw_plotly_code: str) -> str: - # Remove the fig.show() statement from the plotly code - plotly_code = raw_plotly_code.replace("fig.show()", "") - - return plotly_code - - def generate_plotly_code( - self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs - ) -> str: - if question is not None: - system_msg = f"The following is a pandas DataFrame that contains the results of the query that answers the question the user asked: '{question}'" - else: - system_msg = "The following is a pandas DataFrame " - - if sql is not None: - system_msg += f"\n\nThe DataFrame was produced using this query: {sql}\n\n" - - system_msg += f"The following is information about the resulting pandas DataFrame 'df': \n{df_metadata}" - - message_log = [ - self.system_message(system_msg), - self.user_message( - "Can you generate the Python plotly code to chart the results of the dataframe? Assume the data is in a pandas dataframe called 'df'. If there is only one value in the dataframe, use an Indicator. Respond with only Python code. Do not answer with any explanations -- just the code." - ), - ] - - plotly_code = self.submit_prompt(message_log, kwargs=kwargs) - - return self._sanitize_plotly_code(self._extract_python_code(plotly_code)) - def submit_prompt(self, prompt, **kwargs) -> str: if prompt is None: raise Exception("Prompt is None") From e995493bcd189f3052c99ea8295c789a6de1aeea Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Sat, 27 Jan 2024 10:00:49 -0500 Subject: [PATCH 2/8] remove 25 piece training data restriction --- src/vanna/flask.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vanna/flask.py b/src/vanna/flask.py index 57bec28f..5177ceaf 100644 --- a/src/vanna/flask.py +++ b/src/vanna/flask.py @@ -251,7 +251,7 @@ def get_training_data(): { "type": "df", "id": "training_data", - "df": df.tail(25).to_json(orient="records"), + "df": df.to_json(orient="records"), } ) From 6c6a09cbab63a3b3b9b868965b59cfd0bb56005d Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Sat, 10 Feb 2024 10:19:40 -0500 Subject: [PATCH 3/8] new tests --- .gitattributes | 1 + .gitignore | 7 +- pyproject.toml | 4 +- src/vanna/__init__.py | 1926 ++------------------------ src/vanna/base/base.py | 244 +++- src/vanna/flask.py | 8 + src/vanna/ollama/__init__.py | 27 +- src/vanna/remote.py | 9 + tests/fixtures/questions.json | 10 - tests/fixtures/sql/testSqlCreate.sql | 1 - tests/fixtures/sql/testSqlSelect.sql | 1 - tests/test_vanna.py | 533 +------ tox.ini | 10 + 13 files changed, 412 insertions(+), 2369 deletions(-) create mode 100644 .gitattributes delete mode 100644 tests/fixtures/questions.json delete mode 100644 tests/fixtures/sql/testSqlCreate.sql delete mode 100644 tests/fixtures/sql/testSqlSelect.sql diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..a894e29e --- /dev/null +++ b/.gitattributes @@ -0,0 +1 @@ +*.ipynb linguist-detectable=false diff --git a/.gitignore b/.gitignore index c04a3e74..d7285afd 100644 --- a/.gitignore +++ b/.gitignore @@ -2,8 +2,7 @@ build **.egg-info venv .DS_Store -notebooks/.ipynb_checkpoints -notebooks/test*.ipynb +notebooks/* tests/__pycache__ __pycache__/ .idea @@ -12,4 +11,6 @@ docs/*.html .ipynb_checkpoints/ .tox/ notebooks/chroma.sqlite3 -dist \ No newline at end of file +dist +.env +*.sqlite \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index df6687c3..8f2f2486 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "vanna" -version = "0.0.36" +version = "0.1.0" authors = [ { name="Zain Hoda", email="zain@vanna.ai" }, ] @@ -30,7 +30,7 @@ postgres = ["psycopg2-binary", "db-dtypes"] bigquery = ["google-cloud-bigquery"] snowflake = ["snowflake-connector-python"] duckdb = ["duckdb"] -all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python","duckdb"] +all = ["psycopg2-binary", "db-dtypes", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai"] test = ["tox"] chromadb = ["chromadb"] openai = ["openai"] diff --git a/src/vanna/__init__.py b/src/vanna/__init__.py index 55177031..2642b6dc 100644 --- a/src/vanna/__init__.py +++ b/src/vanna/__init__.py @@ -1,186 +1,23 @@ -r""" -# Source Code -View the source code on GitHub: [https://github.com/vanna-ai/vanna](https://github.com/vanna-ai/vanna) - -# Basic Usage - -## Getting an API key -```python -import vanna as vn -api_key = vn.get_api_key('my-email@example.com') -vn.set_api_key(api_key) -``` - -## Setting the model -```python -vn.set_model('chinook') -``` - -## Asking a question -```python -vn.ask(question='What are the top 10 artists by sales?') -``` -`vn.ask(...)` is a convenience wrapper around `vn.generate_sql(...)`, `vn.run_sql(...)`, `vn.generate_plotly_code(...)`, `vn.get_plotly_figure(...)`, and `vn.generate_followup_questions(...)`. - -For a runnable notebook where you can ask questions, see [here](/docs/getting-started.html) - -## Training -There are 3 main types of training data that you can add to a model: SQL, DDL, and documentation. -```python -# DDL Statements -vn.train(ddl='CREATE TABLE employees (id INT, name VARCHAR(255), salary INT)') - -# Documentation -vn.train(documentation='Our organization\'s definition of sales is the discount price of an item multiplied by the quantity sold.') - -# SQL -vn.train(sql='SELECT AVG(salary) FROM employees') -``` - -`vn.train(...)` is a convenience wrapper around `vn.add_sql(...)`, `vn.add_ddl(...)`, and `vn.add_documentation(...)`. - -For a runnable notebook where you can train a model, see [here](/docs/manual-train.html) - - -# Nomenclature - -| Prefix | Definition | Examples | -| --- | --- | --- | -| `vn.set_` | Sets the variable for the current session | [`vn.set_model(...)`][vanna.set_model]
[`vn.set_api_key(...)`][vanna.set_api_key] | -| `vn.get_` | Performs a read-only operation | [`vn.get_model()`][vanna.get_models] | -| `vn.add_` | Adds something to the model | [`vn.add_sql(...)`][vanna.add_sql]
[`vn.add_ddl(...)`][vanna.add_ddl] | -| `vn.generate_` | Generates something using AI based on the information in the model | [`vn.generate_sql(...)`][vanna.generate_sql]
[`vn.generate_explanation()`][vanna.generate_explanation] | -| `vn.run_` | Runs code (SQL or Plotly) | [`vn.run_sql`][vanna.run_sql] | -| `vn.remove_` | Removes something from the model | [`vn.remove_training_data`][vanna.remove_training_data] | -| `vn.update_` | Updates something in the model | [`vn.update_model_visibility(...)`][vanna.update_model_visibility] | -| `vn.connect_` | Connects to a database | [`vn.connect_to_snowflake(...)`][vanna.connect_to_snowflake] | - -# Permissions -By default when you create a model it is private. You can add members or admins to your model or make it public. - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
User RolePublic ModelPrivate Model
UseTrainUseTrain
Non-Member
Member
Admin
- -# Open-Source and Extending - -Vanna.AI is open-source and extensible. If you'd like to use Vanna without the servers, see an example [here](/docs/local.html). - -The following is an example of where various functions are implemented in the codebase when using the default "local" version of Vanna. `vanna.base.VannaBase` is the base class which provides a `vanna.base.VannaBase.ask` and `vanna.base.VannaBase.train` function. Those rely on abstract methods which are implemented in the subclasses `vanna.openai_chat.OpenAI_Chat` and `vanna.chromadb_vector.ChromaDB_VectorStore`. `vanna.openai_chat.OpenAI_Chat` uses the OpenAI API to generate SQL and Plotly code. `vanna.chromadb_vector.ChromaDB_VectorStore` uses ChromaDB to store training data and generate embeddings. - -If you want to use Vanna with other LLMs or databases, you can create your own subclass of `vanna.base.VannaBase` and implement the abstract methods. - -```mermaid -flowchart - subgraph VannaBase - ask - train - end - - subgraph OpenAI_Chat - get_sql_prompt - submit_prompt - generate_question - generate_plotly_code - end - - subgraph ChromaDB_VectorStore - generate_embedding - add_question_sql - add_ddl - add_documentation - get_similar_question_sql - get_related_ddl - get_related_documentation - end -``` - - -# API Reference -""" - import dataclasses import json import os -import sqlite3 -import traceback -import warnings from dataclasses import dataclass from typing import Callable, List, Tuple, Union -from urllib.parse import urlparse import pandas as pd -import plotly -import plotly.express as px -import plotly.graph_objects as go import requests -import sqlparse +import plotly.graph_objs from .exceptions import ( - APIError, - ConnectionError, - DependencyError, - ImproperlyConfigured, OTPCodeError, - SQLRemoveError, ValidationError, ) from .types import ( - AccuracyStats, ApiKey, - DataFrameJSON, - DataResult, - Explanation, - FullQuestionDocument, - NewOrganization, - NewOrganizationMember, - Organization, - OrganizationList, - PlotlyResult, - Question, - QuestionCategory, - QuestionId, - QuestionList, - QuestionSQLPair, - QuestionStringList, - SQLAnswer, Status, - StringData, TrainingData, UserEmail, UserOTP, - Visibility, ) from .utils import sanitize_model_name, validate_config_path @@ -204,9 +41,19 @@ __org: Union[str, None] = None # Organization name for Vanna.AI -_endpoint = "https://ask.vanna.ai/rpc" _unauthenticated_endpoint = "https://ask.vanna.ai/unauthenticated_rpc" +def error_deprecation(): + raise Exception(""" +Please switch to the following method for initializing Vanna: + +from vanna.remote import VannaDefault + +api_key = # Your API key from https://vanna.ai/account/profile +vanna_model_name = # Your model name from https://vanna.ai/account/profile + +vn = VannaDefault(model=vanna_model_name, api_key=api_key) +""") def __unauthenticated_rpc_call(method, params): headers = { @@ -220,39 +67,6 @@ def __unauthenticated_rpc_call(method, params): return response.json() -def __rpc_call(method, params): - global api_key - global __org - - if api_key is None: - raise ImproperlyConfigured( - "API key not set. Use vn.get_api_key(...) to get an API key." - ) - - if __org is None and method != "list_orgs": - raise ImproperlyConfigured( - "model not set. Use vn.set_model(...) to set the model to use." - ) - - if method != "list_orgs": - headers = { - "Content-Type": "application/json", - "Vanna-Key": api_key, - "Vanna-Org": __org, - } - else: - headers = { - "Content-Type": "application/json", - "Vanna-Key": api_key, - "Vanna-Org": "demo-tpc-h", - } - - data = {"method": method, "params": [__dataclass_to_dict(obj) for obj in params]} - - response = requests.post(_endpoint, headers=headers, data=json.dumps(data)) - return response.json() - - def __dataclass_to_dict(obj): return dataclasses.asdict(obj) @@ -316,300 +130,41 @@ def get_api_key(email: str, otp_code: Union[str, None] = None) -> str: def set_api_key(key: str) -> None: - """ - Sets the API key for Vanna.AI. - - **Example:** - ```python - api_key = vn.get_api_key(email="my-email@example.com") - vn.set_api_key(api_key) - ``` - - Args: - key (str): The API key. - """ - global api_key - api_key = key - - models = get_models() - - if len(models) == 0: - raise ConnectionError( - "There was an error communicating with the Vanna.AI API. Please try again or contact support@vanna.ai" - ) + error_deprecation() def get_models() -> List[str]: - """ - **Example:** - ```python - models = vn.get_models() - ``` - - List the models that the user is a member of. - - Returns: - List[str]: A list of model names. - """ - d = __rpc_call(method="list_orgs", params=[]) - - if "result" not in d: - return [] - - orgs = OrganizationList(**d["result"]) - - return orgs.organizations + error_deprecation() def create_model(model: str, db_type: str) -> bool: - """ - **Example:** - ```python - vn.create_model(model="my-model", db_type="postgres") - ``` - - Create a new model. - - Args: - model (str): The name of the model to create. - db_type (str): The type of database to use for the model. This can be "Snowflake", "BigQuery", "Postgres", or anything else. - - Returns: - bool: True if the model was created successfully, False otherwise. - """ - global __org - if __org is None: - __org = "demo-tpc-h" - model = sanitize_model_name(model) - params = [NewOrganization(org_name=model, db_type=db_type)] - - d = __rpc_call(method="create_org", params=params) - - if "result" not in d: - return False - - status = Status(**d["result"]) - - if status.success: - __org = model - - return status.success + error_deprecation() def add_user_to_model(model: str, email: str, is_admin: bool) -> bool: - """ - **Example:** - ```python - vn.add_user_to_model(model="my-model", email="user@example.com") - ``` - - Add a user to an model. - - Args: - model (str): The name of the model to add the user to. - email (str): The email address of the user to add. - is_admin (bool): Whether or not the user should be an admin. - - Returns: - bool: True if the user was added successfully, False otherwise. - """ - - params = [NewOrganizationMember(org_name=model, email=email, is_admin=is_admin)] - - d = __rpc_call(method="add_user_to_org", params=params) - - if "result" not in d: - return False - - status = Status(**d["result"]) - - if not status.success: - print(status.message) - - return status.success + error_deprecation() def update_model_visibility(public: bool) -> bool: - """ - **Example:** - ```python - vn.update_model_visibility(public=True) - ``` - - Set the visibility of the current model. If a model is visible, anyone can see it. If it is not visible, only members of the model can see it. - - Args: - public (bool): Whether or not the model should be publicly visible. - - Returns: - bool: True if the model visibility was set successfully, False otherwise. - """ - params = [Visibility(visibility=public)] - - d = __rpc_call(method="set_org_visibility", params=params) - - if "result" not in d: - return False - - status = Status(**d["result"]) - - return status.success - - -def _set_org(org: str) -> None: - global __org - - my_orgs = get_models() - if org not in my_orgs: - # Check if org exists - d = __unauthenticated_rpc_call( - method="check_org_exists", - params=[Organization(name=org, user=None, connection=None)], - ) - - if "result" not in d: - raise ValidationError("Failed to check if model exists") - - status = Status(**d["result"]) - - if status.success: - raise ValidationError(f"An organization with the name {org} already exists") - - create = input(f"Would you like to create model '{org}'? (y/n): ") - - if create.lower() == "y": - db_type = input( - "What type of database would you like to use? (Snowflake, BigQuery, Postgres, etc.): " - ) - if create_model(model=org, db_type=db_type): - __org = org - else: - __org = None - raise ValidationError("Failed to create model") - else: - __org = org + error_deprecation() def set_model(model: str): - """ - Set the models to use for the Vanna.AI API. - - **Example:** - ```python - vn.set_model("my-model") - ``` - - Args: - model (str): The name of the model to use. - """ - if type(model) is not str: - raise ValidationError( - f"Please provide model name in string format and not {type(model)}." - ) - - if model == "my-model": - env_model = os.environ.get("VANNA_MODEL", None) - - if env_model is not None: - model = env_model - else: - raise ValidationError( - "Please replace 'my-model' with the name of your model" - ) - - _set_org(org=model) + error_deprecation() def add_sql( question: str, sql: str, tag: Union[str, None] = "Manually Trained" ) -> bool: - """ - Adds a question and its corresponding SQL query to the model's training data. The preferred way to call this is to use [`vn.train(sql=...)`][vanna.train]. - - **Example:** - ```python - vn.add_sql( - question="What is the average salary of employees?", - sql="SELECT AVG(salary) FROM employees" - ) - ``` - - Args: - question (str): The question to store. - sql (str): The SQL query to store. - tag (Union[str, None]): A tag to associate with the question and SQL query. - - Returns: - bool: True if the question and SQL query were stored successfully, False otherwise. - """ - params = [QuestionSQLPair(question=question, sql=sql, tag=tag)] - - d = __rpc_call(method="store_sql", params=params) - - if "result" not in d: - return False - - status = Status(**d["result"]) - - return status.success + error_deprecation() def add_ddl(ddl: str) -> bool: - """ - Adds a DDL statement to the model's training data - - **Example:** - ```python - vn.add_ddl( - ddl="CREATE TABLE employees (id INT, name VARCHAR(255), salary INT)" - ) - ``` - - Args: - ddl (str): The DDL statement to store. - - Returns: - bool: True if the DDL statement was stored successfully, False otherwise. - """ - params = [StringData(data=ddl)] - - d = __rpc_call(method="store_ddl", params=params) - - if "result" not in d: - return False - - status = Status(**d["result"]) - - return status.success + error_deprecation() def add_documentation(documentation: str) -> bool: - """ - Adds documentation to the model's training data - - **Example:** - ```python - vn.add_documentation( - documentation="Our organization's definition of sales is the discount price of an item multiplied by the quantity sold." - ) - ``` - - Args: - documentation (str): The documentation string to store. - - Returns: - bool: True if the documentation string was stored successfully, False otherwise. - """ - params = [StringData(data=documentation)] - - d = __rpc_call(method="store_documentation", params=params) - - if "result" not in d: - return False - - status = Status(**d["result"]) - - return status.success + error_deprecation() @dataclass @@ -693,145 +248,17 @@ def remove_item(self, item: str): break -def __get_databases() -> List[str]: - try: - df_databases = run_sql("SELECT * FROM INFORMATION_SCHEMA.DATABASES") - except: - try: - df_databases = run_sql("SHOW DATABASES") - except: - return [] - - return df_databases["DATABASE_NAME"].unique().tolist() - - -def __get_information_schema_tables(database: str) -> pd.DataFrame: - df_tables = run_sql(f"SELECT * FROM {database}.INFORMATION_SCHEMA.TABLES") - - return df_tables - - def get_training_plan_postgres( filter_databases: Union[List[str], None] = None, filter_schemas: Union[List[str], None] = None, include_information_schema: bool = False, use_historical_queries: bool = True, ) -> TrainingPlan: - plan = TrainingPlan([]) - - if run_sql is None: - raise ValidationError("Please connect to a database first.") - - df_columns = run_sql("select * from INFORMATION_SCHEMA.COLUMNS") - - databases = df_columns["table_catalog"].unique().tolist() - - for database in databases: - if filter_databases is not None and database not in filter_databases: - continue - - for schema in ( - df_columns.query(f'table_catalog == "{database}"')["table_schema"] - .unique() - .tolist() - ): - if filter_schemas is not None and schema not in filter_schemas: - continue - - if not include_information_schema and ( - schema == "information_schema" or schema == "pg_catalog" - ): - continue - - df_columns_filtered = df_columns.query( - f'table_catalog == "{database}" and table_schema == "{schema}"' - ) - - for table in df_columns_filtered["table_name"].unique().tolist(): - df_columns_filtered_to_table = df_columns_filtered.query( - f'table_name == "{table}"' - ) - doc = f"The following columns are in the {table} table in the {database} database:\n\n" - doc += df_columns_filtered_to_table[ - [ - "table_catalog", - "table_schema", - "table_name", - "column_name", - "data_type", - ] - ].to_markdown() - - plan._plan.append( - TrainingPlanItem( - item_type=TrainingPlanItem.ITEM_TYPE_IS, - item_group=f"{database}.{schema}", - item_name=table, - item_value=doc, - ) - ) - - return plan + error_deprecation() def get_training_plan_generic(df) -> TrainingPlan: - # For each of the following, we look at the df columns to see if there's a match: - database_column = df.columns[ - df.columns.str.lower().str.contains("database") - | df.columns.str.lower().str.contains("table_catalog") - ].to_list()[0] - schema_column = df.columns[ - df.columns.str.lower().str.contains("table_schema") - ].to_list()[0] - table_column = df.columns[ - df.columns.str.lower().str.contains("table_name") - ].to_list()[0] - column_column = df.columns[ - df.columns.str.lower().str.contains("column_name") - ].to_list()[0] - data_type_column = df.columns[ - df.columns.str.lower().str.contains("data_type") - ].to_list()[0] - - plan = TrainingPlan([]) - - for database in df[database_column].unique().tolist(): - for schema in ( - df.query(f'{database_column} == "{database}"')[schema_column] - .unique() - .tolist() - ): - for table in ( - df.query( - f'{database_column} == "{database}" and {schema_column} == "{schema}"' - )[table_column] - .unique() - .tolist() - ): - df_columns_filtered_to_table = df.query( - f'{database_column} == "{database}" and {schema_column} == "{schema}" and {table_column} == "{table}"' - ) - doc = f"The following columns are in the {table} table in the {database} database:\n\n" - doc += df_columns_filtered_to_table[ - [ - database_column, - schema_column, - table_column, - column_column, - data_type_column, - ] - ].to_markdown() - - plan._plan.append( - TrainingPlanItem( - item_type=TrainingPlanItem.ITEM_TYPE_IS, - item_group=f"{database}.{schema}", - item_name=table, - item_value=doc, - ) - ) - - return plan + error_deprecation() def get_training_plan_experimental( @@ -840,184 +267,7 @@ def get_training_plan_experimental( include_information_schema: bool = False, use_historical_queries: bool = True, ) -> TrainingPlan: - """ - **EXPERIMENTAL** : This method is experimental and may change in future versions. - - Get a training plan based on the metadata in the database. Currently this only works for Snowflake. - - **Example:** - ```python - plan = vn.get_training_plan_experimental(filter_databases=["employees"], filter_schemas=["public"]) - - vn.train(plan=plan) - ``` - """ - - plan = TrainingPlan([]) - - if run_sql is None: - raise ValidationError("Please connect to a database first.") - - if use_historical_queries: - try: - print("Trying query history") - df_history = run_sql( - """ select * from table(information_schema.query_history(result_limit => 5000)) order by start_time""" - ) - - df_history_filtered = df_history.query("ROWS_PRODUCED > 1") - if filter_databases is not None: - mask = ( - df_history_filtered["QUERY_TEXT"] - .str.lower() - .apply( - lambda x: any( - s in x for s in [s.lower() for s in filter_databases] - ) - ) - ) - df_history_filtered = df_history_filtered[mask] - - if filter_schemas is not None: - mask = ( - df_history_filtered["QUERY_TEXT"] - .str.lower() - .apply( - lambda x: any( - s in x for s in [s.lower() for s in filter_schemas] - ) - ) - ) - df_history_filtered = df_history_filtered[mask] - - for query in df_history_filtered.sample(10)["QUERY_TEXT"].unique().tolist(): - plan._plan.append( - TrainingPlanItem( - item_type=TrainingPlanItem.ITEM_TYPE_SQL, - item_group="", - item_name=generate_question(query), - item_value=query, - ) - ) - - except Exception as e: - print(e) - - databases = __get_databases() - - for database in databases: - if filter_databases is not None and database not in filter_databases: - continue - - try: - df_tables = __get_information_schema_tables(database=database) - - print(f"Trying INFORMATION_SCHEMA.COLUMNS for {database}") - df_columns = run_sql(f"SELECT * FROM {database}.INFORMATION_SCHEMA.COLUMNS") - - for schema in df_tables["TABLE_SCHEMA"].unique().tolist(): - if filter_schemas is not None and schema not in filter_schemas: - continue - - if not include_information_schema and schema == "INFORMATION_SCHEMA": - continue - - df_columns_filtered_to_schema = df_columns.query( - f"TABLE_SCHEMA == '{schema}'" - ) - - try: - tables = ( - df_columns_filtered_to_schema["TABLE_NAME"].unique().tolist() - ) - - for table in tables: - df_columns_filtered_to_table = ( - df_columns_filtered_to_schema.query( - f"TABLE_NAME == '{table}'" - ) - ) - doc = f"The following columns are in the {table} table in the {database} database:\n\n" - doc += df_columns_filtered_to_table[ - [ - "TABLE_CATALOG", - "TABLE_SCHEMA", - "TABLE_NAME", - "COLUMN_NAME", - "DATA_TYPE", - "COMMENT", - ] - ].to_markdown() - - plan._plan.append( - TrainingPlanItem( - item_type=TrainingPlanItem.ITEM_TYPE_IS, - item_group=f"{database}.{schema}", - item_name=table, - item_value=doc, - ) - ) - - except Exception as e: - print(e) - pass - except Exception as e: - print(e) - - # try: - # print("Trying SHOW TABLES") - # df_f = run_sql("SHOW TABLES") - - # for schema in df_f.schema_name.unique(): - # try: - # print(f"Trying GET_DDL for {schema}") - # ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')") - - # plan._plan.append(TrainingPlanItem( - # item_type=TrainingPlanItem.ITEM_TYPE_DDL, - # item_group=schema, - # item_name="All Tables", - # item_value=ddl_df.iloc[0, 0] - # )) - # except: - # pass - # except: - # try: - # print("Trying INFORMATION_SCHEMA.TABLES") - # df = run_sql("SELECT * FROM INFORMATION_SCHEMA.TABLES") - - # breakpoint() - - # try: - # print("Trying SCHEMATA") - # df_schemata = run_sql("SELECT * FROM region-us.INFORMATION_SCHEMA.SCHEMATA") - - # for schema in df_schemata.schema_name.unique(): - # df = run_sql(f"SELECT * FROM {schema}.information_schema.tables") - - # for table in df.table_name.unique(): - # plan._plan.append(TrainingPlanItem( - # item_type=TrainingPlanItem.ITEM_TYPE_IS, - # item_group=schema, - # item_name=table, - # item_value=None - # )) - - # try: - # ddl_df = run_sql(f"SELECT GET_DDL('schema', '{schema}')") - - # plan._plan.append(TrainingPlanItem( - # item_type=TrainingPlanItem.ITEM_TYPE_DDL, - # item_group=schema, - # item_name=None, - # item_value=ddl_df.iloc[0, 0] - # )) - # except: - # pass - # except: - # pass - - return plan + error_deprecation() def train( @@ -1029,805 +279,96 @@ def train( sql_file: str = None, plan: TrainingPlan = None, ) -> bool: - """ - **Example:** - ```python - vn.train() - ``` - - Train Vanna.AI on a question and its corresponding SQL query. - If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database. - If you call it with the sql argument, it's equivalent to [`add_sql()`][vanna.add_sql]. - If you call it with the ddl argument, it's equivalent to [`add_ddl()`][vanna.add_ddl]. - If you call it with the documentation argument, it's equivalent to [`add_documentation()`][vanna.add_documentation]. - It can also accept a JSON file path or SQL file path to train on a batch of questions and SQL queries or a list of SQL queries respectively. - Additionally, you can pass a [`TrainingPlan`][vanna.TrainingPlan] object. Get a training plan with [`vn.get_training_plan_experimental()`][vanna.get_training_plan_experimental]. - - Args: - question (str): The question to train on. - sql (str): The SQL query to train on. - sql_file (str): The SQL file path. - json_file (str): The JSON file path. - ddl (str): The DDL statement. - documentation (str): The documentation to train on. - plan (TrainingPlan): The training plan to train on. - """ - - if question and not sql: - example_question = "What is the average salary of employees?" - raise ValidationError( - f"Please also provide a SQL query \n Example Question: {example_question}\n Answer: {ask(question=example_question)}" - ) - - if documentation: - print("Adding documentation....") - return add_documentation(documentation) - - if sql: - if question is None: - question = generate_question(sql) - print("Question generated with sql:", question, "\nAdding SQL...") - return add_sql(question=question, sql=sql) - - if ddl: - print("Adding ddl:", ddl) - return add_ddl(ddl) - - if json_file: - validate_config_path(json_file) - with open(json_file, "r") as js_file: - data = json.load(js_file) - print("Adding Questions And SQLs using file:", json_file) - for question in data: - if not add_sql(question=question["question"], sql=question["answer"]): - print( - f"Not able to add sql for question: {question['question']} from {json_file}" - ) - return False - return True - - if sql_file: - validate_config_path(sql_file) - with open(sql_file, "r") as file: - sql_statements = sqlparse.split(file.read()) - for statement in sql_statements: - if "CREATE TABLE" in statement: - if add_ddl(statement): - print("ddl Added!") - return True - print("Not able to add DDL") - return False - else: - question = generate_question(sql=statement) - if add_sql(question=question, sql=statement): - print("SQL added!") - return True - print("Not able to add sql.") - return False - return False - - if plan: - for item in plan._plan: - if item.item_type == TrainingPlanItem.ITEM_TYPE_DDL: - if not add_ddl(item.item_value): - print(f"Not able to add ddl for {item.item_group}") - return False - elif item.item_type == TrainingPlanItem.ITEM_TYPE_IS: - if not add_documentation(item.item_value): - print( - f"Not able to add documentation for {item.item_group}.{item.item_name}" - ) - return False - elif item.item_type == TrainingPlanItem.ITEM_TYPE_SQL: - if not add_sql(question=item.item_name, sql=item.item_value): - print(f"Not able to add sql for {item.item_group}.{item.item_name}") - return False + error_deprecation() def flag_sql_for_review( question: str, sql: Union[str, None] = None, error_msg: Union[str, None] = None ) -> bool: - """ - **Example:** - ```python - vn.flag_sql_for_review(question="What is the average salary of employees?") - ``` - Flag a question and its corresponding SQL query for review. You can see the tag show up in [`vn.get_all_questions()`][vanna.get_all_questions] + error_deprecation() - Args: - question (str): The question to flag. - sql (str): The SQL query to flag. - error_msg (str): The error message to flag. - Returns: - bool: True if the question and SQL query were flagged successfully, False otherwise. - """ - params = [ - QuestionCategory( - question=question, - category=QuestionCategory.FLAGGED_FOR_REVIEW, - ) - ] +def remove_sql(question: str) -> bool: + error_deprecation() - d = __rpc_call(method="set_accuracy_category", params=params) - if "result" not in d: - return False +def remove_training_data(id: str) -> bool: + error_deprecation() - status = Status(**d["result"]) - return status.success +def generate_sql(question: str) -> str: + error_deprecation() -# def read_questions_from_github(url: str) -> List[QuestionSQLPair]: -# """ -# **Example:** -# ```python -# url = "https://raw.githubusercontent.com/vanna-ai/vanna-ai/main/data/questions.json" -# questions = vn.read_questions_from_github(url) -# ``` -# Read questions and SQL queries from a GitHub URL. +def get_related_training_data(question: str) -> TrainingData: + error_deprecation() -# Args: -# url (str): The URL to read from. -# Returns: -# List[QuestionSQLPair]: A list of [`QuestionSQLPair`][vanna.QuestionSQLPair] objects. -# """ -# response = requests.get(url) -# data = response.json() +def generate_meta(question: str) -> str: + error_deprecation() -# question_sql_pairs = [] -# for item in data: -# question = item.get('question') -# sql = item.get('sql') -# if question and sql: -# question_sql_pair = QuestionSQLPair(question=question, sql=sql) -# question_sql_pairs.append(question_sql_pair) -# return question_sql_pairs +def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]: + error_deprecation() -def remove_sql(question: str) -> bool: - """ - Remove a question and its corresponding SQL query from the model's training data +def generate_questions() -> List[str]: + error_deprecation() - **Example:** - ```python - vn.remove_sql(question="What is the average salary of employees?") - ``` - Args: - question (str): The question to remove. - """ - params = [Question(question=question)] +def ask( + question: Union[str, None] = None, + print_results: bool = True, + auto_train: bool = True, + generate_followups: bool = True, +) -> Union[ + Tuple[ + Union[str, None], + Union[pd.DataFrame, None], + Union[plotly.graph_objs.Figure, None], + Union[List[str], None], + ], + None, +]: + error_deprecation() - d = __rpc_call(method="remove_sql", params=params) +def generate_plotly_code( + question: Union[str, None], + sql: Union[str, None], + df: pd.DataFrame, + chart_instructions: Union[str, None] = None, +) -> str: + error_deprecation() - if "result" not in d: - raise Exception(f"Error removing SQL") - return False - status = Status(**d["result"]) +def get_plotly_figure( + plotly_code: str, df: pd.DataFrame, dark_mode: bool = True +) -> plotly.graph_objs.Figure: + error_deprecation() - if not status.success: - raise SQLRemoveError(f"Error removing SQL: {status.message}") - return status.success +def get_results(cs, default_database: str, sql: str) -> pd.DataFrame: + error_deprecation() -def remove_training_data(id: str) -> bool: - """ - Remove training data from the model - - **Example:** - ```python - vn.remove_training_data(id="1-ddl") - ``` - - Args: - id (str): The ID of the training data to remove. - """ - params = [StringData(data=id)] - - d = __rpc_call(method="remove_training_data", params=params) - - if "result" not in d: - raise APIError(f"Error removing training data") - - status = Status(**d["result"]) - - if not status.success: - raise APIError(f"Error removing training data: {status.message}") - - return status.success - - -def generate_sql(question: str) -> str: - """ - **Example:** - ```python - vn.generate_sql(question="What is the average salary of employees?") - # SELECT AVG(salary) FROM employees - ``` - - Generate an SQL query using the Vanna.AI API. - - Args: - question (str): The question to generate an SQL query for. - - Returns: - str or None: The SQL query, or None if an error occurred. - """ - params = [Question(question=question)] - - d = __rpc_call(method="generate_sql_from_question", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - sql_answer = SQLAnswer(**d["result"]) - - return sql_answer.sql - - -def get_related_training_data(question: str) -> TrainingData: - """ - **Example:** - ```python - training_data = vn.get_related_training_data(question="What is the average salary of employees?") - ``` - - Get the training data related to a question. - - Args: - question (str): The question to get related training data for. - - Returns: - TrainingData or None: The related training data, or None if an error occurred. - """ - params = [Question(question=question)] - - d = __rpc_call(method="get_related_training_data", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - training_data = TrainingData(**d["result"]) - - return training_data - - -def generate_meta(question: str) -> str: - """ - **Example:** - ```python - vn.generate_meta(question="What tables are in the database?") - # Information about the tables in the database - ``` - - Generate answers about the metadata of a database using the Vanna.AI API. - - Args: - question (str): The question to generate an answer for. - - Returns: - str or None: The answer, or None if an error occurred. - """ - params = [Question(question=question)] - - d = __rpc_call(method="generate_meta_from_question", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - string_data = StringData(**d["result"]) - - return string_data.data - - -def generate_followup_questions(question: str, df: pd.DataFrame) -> List[str]: - """ - **Example:** - ```python - vn.generate_followup_questions(question="What is the average salary of employees?", df=df) - # ['What is the average salary of employees in the Sales department?', 'What is the average salary of employees in the Engineering department?', ...] - ``` - - Generate follow-up questions using the Vanna.AI API. - - Args: - question (str): The question to generate follow-up questions for. - df (pd.DataFrame): The DataFrame to generate follow-up questions for. - - Returns: - List[str] or None: The follow-up questions, or None if an error occurred. - """ - params = [ - DataResult( - question=question, - sql=None, - table_markdown="", - error=None, - correction_attempts=0, - ) - ] - - d = __rpc_call(method="generate_followup_questions", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - question_string_list = QuestionStringList(**d["result"]) - - return question_string_list.questions - - -def generate_questions() -> List[str]: - """ - **Example:** - ```python - vn.generate_questions() - # ['What is the average salary of employees?', 'What is the total salary of employees?', ...] - ``` - - Generate questions using the Vanna.AI API. - - Returns: - List[str] or None: The questions, or None if an error occurred. - """ - d = __rpc_call(method="generate_questions", params=[]) - - if "result" not in d: - return None - - # Load the result into a dataclass - question_string_list = QuestionStringList(**d["result"]) - - return question_string_list.questions - - -def ask( - question: Union[str, None] = None, - print_results: bool = True, - auto_train: bool = True, - generate_followups: bool = True, -) -> Union[ - Tuple[ - Union[str, None], - Union[pd.DataFrame, None], - Union[plotly.graph_objs.Figure, None], - Union[List[str], None], - ], - None, -]: - """ - **Example:** - ```python - # RECOMMENDED IN A NOTEBOOK: - sql, df, fig, followup_questions = vn.ask() - - - sql, df, fig, followup_questions = vn.ask(question="What is the average salary of employees?") - # SELECT AVG(salary) FROM employees - ``` - - Ask a question using the Vanna.AI API. This generates an SQL query, runs it, and returns the results in a dataframe and a Plotly figure. - If you set print_results to True, the sql, dataframe, and figure will be output to the screen instead of returned. - - Args: - question (str): The question to ask. If None, you will be prompted to enter a question. - print_results (bool): Whether to print the SQL query and results. - auto_train (bool): Whether to automatically train the model if the SQL query is incorrect. - generate_followups (bool): Whether to generate follow-up questions. - - Returns: - str or None: The SQL query, or None if an error occurred. - pd.DataFrame or None: The results of the SQL query, or None if an error occurred. - plotly.graph_objs.Figure or None: The Plotly figure, or None if an error occurred. - List[str] or None: The follow-up questions, or None if an error occurred. - """ - - if question is None: - question = input("Enter a question: ") - - try: - sql = generate_sql(question=question) - except Exception as e: - print(e) - return None, None, None, None - - if print_results: - try: - Code = __import__("IPython.display", fromlist=["Code"]).Code - display(Code(sql)) - except Exception as e: - print(sql) - - if run_sql is None: - print("If you want to run the SQL query, provide a vn.run_sql function.") - - if print_results: - return None - else: - return sql, None, None, None - - try: - df = run_sql(sql) - - if print_results: - try: - display = __import__("IPython.display", fromlist=["display"]).display - display(df) - except Exception as e: - print(df) - - if len(df) > 0 and auto_train: - add_sql(question=question, sql=sql, tag=types.QuestionCategory.SQL_RAN) - - try: - if df is not None and len(df) > 1: - plotly_code = generate_plotly_code(question=question, sql=sql, df=df) - fig = get_plotly_figure(plotly_code=plotly_code, df=df, dark_mode=False) - if print_results: - try: - display = __import__( - "IPython.display", fromlist=["display"] - ).display - - global fig_as_img - if fig_as_img: - Image = __import__( - "IPython.display", fromlist=["Image"] - ).Image - img_bytes = fig.to_image(format="png", scale=2) - display(Image(img_bytes)) - else: - fig.show() - except Exception as e: - fig.show() - - if generate_followups: - followup_questions = generate_followup_questions( - question=question, df=df - ) - if ( - print_results - and followup_questions is not None - and len(followup_questions) > 0 - ): - md = "AI-generated follow-up questions:\n\n" - for followup_question in followup_questions: - md += f"* {followup_question}\n" - - try: - display = __import__( - "IPython.display", fromlist=["display"] - ).display - Markdown = __import__( - "IPython.display", fromlist=["Markdown"] - ).Markdown - display(Markdown(md)) - except Exception as e: - print(md) - - if print_results: - return None - else: - return sql, df, fig, followup_questions - - if print_results: - return None - else: - return sql, df, fig, None - - except Exception as e: - # Print stack trace - traceback.print_exc() - print("Couldn't run plotly code: ", e) - if print_results: - return None - else: - return sql, df, None, None - - except Exception as e: - print("Couldn't run sql: ", e) - if print_results: - return None - else: - return sql, None, None, None - - -def generate_plotly_code( - question: Union[str, None], - sql: Union[str, None], - df: pd.DataFrame, - chart_instructions: Union[str, None] = None, -) -> str: - """ - **Example:** - ```python - vn.generate_plotly_code( - question="What is the average salary of employees?", - sql="SELECT AVG(salary) FROM employees", - df=df - ) - # fig = px.bar(df, x="name", y="salary") - ``` - Generate Plotly code using the Vanna.AI API. - - Args: - question (str): The question to generate Plotly code for. - sql (str): The SQL query to generate Plotly code for. - df (pd.DataFrame): The dataframe to generate Plotly code for. - chart_instructions (str): Optional instructions for how to plot the chart. - - Returns: - str or None: The Plotly code, or None if an error occurred. - """ - if chart_instructions is not None: - if question is not None: - question = ( - question - + " -- When plotting, follow these instructions: " - + chart_instructions - ) - else: - question = "When plotting, follow these instructions: " + chart_instructions - - params = [ - DataResult( - question=question, - sql=sql, - table_markdown=str(df.dtypes), - error=None, - correction_attempts=0, - ) - ] - - d = __rpc_call(method="generate_plotly_code", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - plotly_code = PlotlyResult(**d["result"]) - - return plotly_code.plotly_code - - -def get_plotly_figure( - plotly_code: str, df: pd.DataFrame, dark_mode: bool = True -) -> plotly.graph_objs.Figure: - """ - **Example:** - ```python - fig = vn.get_plotly_figure( - plotly_code="fig = px.bar(df, x='name', y='salary')", - df=df - ) - fig.show() - ``` - Get a Plotly figure from a dataframe and Plotly code. - - Args: - df (pd.DataFrame): The dataframe to use. - plotly_code (str): The Plotly code to use. - - Returns: - plotly.graph_objs.Figure: The Plotly figure. - """ - ldict = {"df": df, "px": px, "go": go} - exec(plotly_code, globals(), ldict) - - fig = ldict.get("fig", None) - - if fig is None: - return None - - if dark_mode: - fig.update_layout(template="plotly_dark") - - return fig - - -def get_results(cs, default_database: str, sql: str) -> pd.DataFrame: - """ - DEPRECATED. Use `vn.run_sql` instead. - Run the SQL query and return the results as a pandas dataframe. This is just a helper function that does not use the Vanna.AI API. - - Args: - cs: Snowflake connection cursor. - default_database (str): The default database to use. - sql (str): The SQL query to execute. - - Returns: - pd.DataFrame: The results of the SQL query. - """ - print("`vn.get_results()` is deprecated. Use `vn.run_sql()` instead.") - warnings.warn("`vn.get_results()` is deprecated. Use `vn.run_sql()` instead.") - - cs.execute(f"USE DATABASE {default_database}") - - cur = cs.execute(sql) - - results = cur.fetchall() - - # Create a pandas dataframe from the results - df = pd.DataFrame(results, columns=[desc[0] for desc in cur.description]) - - return df - - -def generate_explanation(sql: str) -> str: - """ - - **Example:** - ```python - vn.generate_explanation(sql="SELECT * FROM students WHERE name = 'John Doe'") - # 'This query selects all columns from the students table where the name is John Doe.' - ``` - - Generate an explanation of an SQL query using the Vanna.AI API. - - Args: - sql (str): The SQL query to generate an explanation for. - - Returns: - str or None: The explanation, or None if an error occurred. - - """ - params = [ - SQLAnswer( - raw_answer="", - prefix="", - postfix="", - sql=sql, - ) - ] - - d = __rpc_call(method="generate_explanation", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - explanation = Explanation(**d["result"]) - - return explanation.explanation +def generate_explanation(sql: str) -> str: + error_deprecation() def generate_question(sql: str) -> str: - """ - - **Example:** - ```python - vn.generate_question(sql="SELECT * FROM students WHERE name = 'John Doe'") - # 'What is the name of the student?' - ``` - - Generate a question from an SQL query using the Vanna.AI API. - - Args: - sql (str): The SQL query to generate a question for. - - Returns: - str or None: The question, or None if an error occurred. - - """ - params = [ - SQLAnswer( - raw_answer="", - prefix="", - postfix="", - sql=sql, - ) - ] - - d = __rpc_call(method="generate_question", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - question = Question(**d["result"]) - - return question.question + error_deprecation() def get_all_questions() -> pd.DataFrame: - """ - Get a list of questions from the Vanna.AI API. - - **Example:** - ```python - questions = vn.get_all_questions() - ``` - - Returns: - pd.DataFrame or None: The list of questions, or None if an error occurred. - - """ - # params = [Question(question="")] - params = [] - - d = __rpc_call(method="get_all_questions", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - all_questions = DataFrameJSON(**d["result"]) - - df = pd.read_json(all_questions.data) - - return df + error_deprecation() def get_training_data() -> pd.DataFrame: - """ - Get the training data for the current model - - **Example:** - ```python - training_data = vn.get_training_data() - ``` - - Returns: - pd.DataFrame or None: The training data, or None if an error occurred. - - """ - # params = [Question(question="")] - params = [] - - d = __rpc_call(method="get_training_data", params=params) - - if "result" not in d: - return None - - # Load the result into a dataclass - training_data = DataFrameJSON(**d["result"]) - - df = pd.read_json(training_data.data) - - return df + error_deprecation() def connect_to_sqlite(url: str): - """ - Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] - - Args: - url (str): The URL of the database to connect to. - - Returns: - None - """ - - # URL of the database to download - - # Path to save the downloaded database - path = "tempdb.sqlite" - - # Download the database if it doesn't exist - if not os.path.exists(path): - response = requests.get(url) - response.raise_for_status() # Check that the request was successful - with open(path, "wb") as f: - f.write(response.content) - - # Connect to the database - conn = sqlite3.connect(path) - - def run_sql_sqlite(sql: str): - return pd.read_sql_query(sql, conn) - - global run_sql - run_sql = run_sql_sqlite + error_deprecation() def connect_to_snowflake( @@ -1838,99 +379,7 @@ def connect_to_snowflake( schema: Union[str, None] = None, role: Union[str, None] = None, ): - """ - Connect to Snowflake using the Snowflake connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] - - **Example:** - ```python - import snowflake.connector - - vn.connect_to_snowflake( - account="myaccount", - username="myusername", - password="mypassword", - database="mydatabase", - role="myrole", - ) - ``` - - Args: - account (str): The Snowflake account name. - username (str): The Snowflake username. - password (str): The Snowflake password. - database (str): The default database to use. - schema (Union[str, None], optional): The schema to use. Defaults to None. - role (Union[str, None], optional): The role to use. Defaults to None. - """ - - try: - snowflake = __import__("snowflake.connector") - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method, run command:" - " \npip install vanna[snowflake]" - ) - - if username == "my-username": - username_env = os.getenv("SNOWFLAKE_USERNAME") - - if username_env is not None: - username = username_env - else: - raise ImproperlyConfigured("Please set your Snowflake username.") - - if password == "my-password": - password_env = os.getenv("SNOWFLAKE_PASSWORD") - - if password_env is not None: - password = password_env - else: - raise ImproperlyConfigured("Please set your Snowflake password.") - - if account == "my-account": - account_env = os.getenv("SNOWFLAKE_ACCOUNT") - - if account_env is not None: - account = account_env - else: - raise ImproperlyConfigured("Please set your Snowflake account.") - - if database == "my-database": - database_env = os.getenv("SNOWFLAKE_DATABASE") - - if database_env is not None: - database = database_env - else: - raise ImproperlyConfigured("Please set your Snowflake database.") - - conn = snowflake.connector.connect( - user=username, - password=password, - account=account, - database=database, - ) - - def run_sql_snowflake(sql: str) -> pd.DataFrame: - cs = conn.cursor() - - if role is not None: - cs.execute(f"USE ROLE {role}") - cs.execute(f"USE DATABASE {database}") - - if schema is not None: - cs.execute(f"USE SCHEMA {schema}") - - cur = cs.execute(sql) - - results = cur.fetchall() - - # Create a pandas dataframe from the results - df = pd.DataFrame(results, columns=[desc[0] for desc in cur.description]) - - return df - - global run_sql - run_sql = run_sql_snowflake + error_deprecation() def connect_to_postgres( @@ -1940,230 +389,11 @@ def connect_to_postgres( password: str = None, port: int = None, ): - """ - Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] - **Example:** - ```python - import psycopg2.connect - vn.connect_to_bigquery( - host="myhost", - dbname="mydatabase", - user="myuser", - password="mypassword", - port=5432 - ) - ``` - Args: - host (str): The postgres host. - dbname (str): The postgres database name. - user (str): The postgres user. - password (str): The postgres password. - port (int): The postgres Port. - """ - - try: - import psycopg2 - import psycopg2.extras - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method," - " run command: \npip install vanna[postgres]" - ) - - if not host: - host = os.getenv("HOST") - - if not host: - raise ImproperlyConfigured("Please set your postgres host") - - if not dbname: - dbname = os.getenv("DATABASE") - - if not dbname: - raise ImproperlyConfigured("Please set your postgres database") - - if not user: - user = os.getenv("PG_USER") - - if not user: - raise ImproperlyConfigured("Please set your postgres user") - - if not password: - password = os.getenv("PASSWORD") - - if not password: - raise ImproperlyConfigured("Please set your postgres password") - - if not port: - port = os.getenv("PORT") - - if not port: - raise ImproperlyConfigured("Please set your postgres port") - - conn = None - - try: - conn = psycopg2.connect( - host=host, - dbname=dbname, - user=user, - password=password, - port=port, - ) - except psycopg2.Error as e: - raise ValidationError(e) - - def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: - try: - with conn.cursor() as cs: # Using a with statement to manage the cursor lifecycle - cs.execute(sql) - results = cs.fetchall() - df = pd.DataFrame(results, columns=[desc[0] for desc in cs.description]) - conn.commit() - return df - except psycopg2.Error as e: - conn.rollback() - raise ValidationError(e) - except Exception as e: - conn.rollback() - raise e - - global run_sql - run_sql = run_sql_postgres + error_deprecation() def connect_to_bigquery(cred_file_path: str = None, project_id: str = None): - """ - Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] - **Example:** - ```python - import bigquery.Client - vn.connect_to_bigquery( - project_id="myprojectid", - cred_file_path="path/to/credentials.json", - ) - ``` - Args: - project_id (str): The gcs project id. - cred_file_path (str): The gcs credential file path - """ - - try: - from google.api_core.exceptions import GoogleAPIError - from google.cloud import bigquery - from google.oauth2 import service_account - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method, run command:" - " \npip install vanna[bigquery]" - ) - - if not project_id: - project_id = os.getenv("PROJECT_ID") - - if not project_id: - raise ImproperlyConfigured("Please set your Google Cloud Project ID.") - - import sys - - if "google.colab" in sys.modules: - try: - from google.colab import auth - - auth.authenticate_user() - except Exception as e: - raise ImproperlyConfigured(e) - else: - print("Not using Google Colab.") - - conn = None - - try: - conn = bigquery.Client(project=project_id) - except: - print("Could not found any google cloud implicit credentials") - - if cred_file_path: - # Validate file path and pemissions - validate_config_path(cred_file_path) - else: - if not conn: - raise ValidationError( - "Pleae provide a service account credentials json file" - ) - - if not conn: - with open(cred_file_path, "r") as f: - credentials = service_account.Credentials.from_service_account_info( - json.loads(f.read()), - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - - try: - conn = bigquery.Client(project=project_id, credentials=credentials) - except: - raise ImproperlyConfigured( - "Could not connect to bigquery please correct credentials" - ) - - def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: - if conn: - try: - job = conn.query(sql) - df = job.result().to_dataframe() - return df - except GoogleAPIError as error: - errors = [] - for error in error.errors: - errors.append(error["message"]) - raise errors - return None - - global run_sql - run_sql = run_sql_bigquery + error_deprecation() def connect_to_duckdb(url: str="memory", init_sql: str = None): - """ - Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] - - Args: - url (str): The URL of the database to connect to. - init_sql (str, optional): SQL to run when connecting to the database. Defaults to None. - - Returns: - None - """ - try: - import duckdb - except ImportError: - raise DependencyError( - "You need to install required dependencies to execute this method," - " run command: \npip install vanna[duckdb]" - ) - # URL of the database to download - if url==":memory:" or url=="": - path=":memory:" - else: - # Path to save the downloaded database - print(os.path.exists(url)) - if os.path.exists(url): - path=url - else: - path = os.path.basename(urlparse(url).path) - # Download the database if it doesn't exist - if not os.path.exists(path): - response = requests.get(url) - response.raise_for_status() # Check that the request was successful - with open(path, "wb") as f: - f.write(response.content) - - # Connect to the database - conn = duckdb.connect(path) - if init_sql: - conn.query(init_sql) - - def run_sql_duckdb(sql: str): - return conn.query(sql).to_df() - - global run_sql - run_sql = run_sql_duckdb \ No newline at end of file + error_deprecation() \ No newline at end of file diff --git a/src/vanna/base/base.py b/src/vanna/base/base.py index d2a6091a..a0f23523 100644 --- a/src/vanna/base/base.py +++ b/src/vanna/base/base.py @@ -1,3 +1,53 @@ +r""" + +# Nomenclature + +| Prefix | Definition | Examples | +| --- | --- | --- | +| `vn.get_` | Fetch some data | [`vn.get_related_ddl(...)`][vanna.base.base.VannaBase.get_related_ddl] | +| `vn.add_` | Adds something to the retrieval layer | [`vn.add_question_sql(...)`][vanna.base.base.VannaBase.add_question_sql]
[`vn.add_ddl(...)`][vanna.base.base.VannaBase.add_ddl] | +| `vn.generate_` | Generates something using AI based on the information in the model | [`vn.generate_sql(...)`][vanna.base.base.VannaBase.generate_sql]
[`vn.generate_explanation()`][vanna.base.base.VannaBase.generate_explanation] | +| `vn.run_` | Runs code (SQL) | [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] | +| `vn.remove_` | Removes something from the retrieval layer | [`vn.remove_training_data`][vanna.base.base.VannaBase.remove_training_data] | +| `vn.connect_` | Connects to a database | [`vn.connect_to_snowflake(...)`][vanna.base.base.VannaBase.connect_to_snowflake] | +| `vn.update_` | Updates something | N/A -- unused | +| `vn.set_` | Sets something | N/A -- unused | + +# Open-Source and Extending + +Vanna.AI is open-source and extensible. If you'd like to use Vanna without the servers, see an example [here](/docs/local.html). + +The following is an example of where various functions are implemented in the codebase when using the default "local" version of Vanna. `vanna.base.VannaBase` is the base class which provides a `vanna.base.VannaBase.ask` and `vanna.base.VannaBase.train` function. Those rely on abstract methods which are implemented in the subclasses `vanna.openai_chat.OpenAI_Chat` and `vanna.chromadb_vector.ChromaDB_VectorStore`. `vanna.openai_chat.OpenAI_Chat` uses the OpenAI API to generate SQL and Plotly code. `vanna.chromadb_vector.ChromaDB_VectorStore` uses ChromaDB to store training data and generate embeddings. + +If you want to use Vanna with other LLMs or databases, you can create your own subclass of `vanna.base.VannaBase` and implement the abstract methods. + +```mermaid +flowchart + subgraph VannaBase + ask + train + end + + subgraph OpenAI_Chat + get_sql_prompt + submit_prompt + generate_question + generate_plotly_code + end + + subgraph ChromaDB_VectorStore + generate_embedding + add_question_sql + add_ddl + add_documentation + get_similar_question_sql + get_related_ddl + get_related_documentation + end +``` + +""" + import json import os import re @@ -27,6 +77,31 @@ def log(self, message: str): print(message) def generate_sql(self, question: str, **kwargs) -> str: + """ + Example: + ```python + vn.generate_sql("What are the top 10 customers by sales?") + ``` + + Uses the LLM to generate a SQL query that answers a question. It runs the following methods: + + - [`get_similar_question_sql`][vanna.base.base.VannaBase.get_similar_question_sql] + + - [`get_related_ddl`][vanna.base.base.VannaBase.get_related_ddl] + + - [`get_related_documentation`][vanna.base.base.VannaBase.get_related_documentation] + + - [`get_sql_prompt`][vanna.base.base.VannaBase.get_sql_prompt] + + - [`submit_prompt`][vanna.base.base.VannaBase.submit_prompt] + + + Args: + question (str): The question to generate a SQL query for. + + Returns: + str: The SQL query that answers the question. + """ question_sql_list = self.get_similar_question_sql(question, **kwargs) ddl_list = self.get_related_ddl(question, **kwargs) doc_list = self.get_related_documentation(question, **kwargs) @@ -102,34 +177,114 @@ def generate_embedding(self, data: str, **kwargs) -> List[float]: # ----------------- Use Any Database to Store and Retrieve Context ----------------- # @abstractmethod def get_similar_question_sql(self, question: str, **kwargs) -> list: + """ + This method is used to get similar questions and their corresponding SQL statements. + + Args: + question (str): The question to get similar questions and their corresponding SQL statements for. + + Returns: + list: A list of similar questions and their corresponding SQL statements. + """ pass @abstractmethod def get_related_ddl(self, question: str, **kwargs) -> list: + """ + This method is used to get related DDL statements to a question. + + Args: + question (str): The question to get related DDL statements for. + + Returns: + list: A list of related DDL statements. + """ pass @abstractmethod def get_related_documentation(self, question: str, **kwargs) -> list: + """ + This method is used to get related documentation to a question. + + Args: + question (str): The question to get related documentation for. + + Returns: + list: A list of related documentation. + """ pass @abstractmethod def add_question_sql(self, question: str, sql: str, **kwargs) -> str: + """ + This method is used to add a question and its corresponding SQL query to the training data. + + Args: + question (str): The question to add. + sql (str): The SQL query to add. + + Returns: + str: The ID of the training data that was added. + """ pass @abstractmethod def add_ddl(self, ddl: str, **kwargs) -> str: + """ + This method is used to add a DDL statement to the training data. + + Args: + ddl (str): The DDL statement to add. + + Returns: + str: The ID of the training data that was added. + """ pass @abstractmethod def add_documentation(self, documentation: str, **kwargs) -> str: + """ + This method is used to add documentation to the training data. + + Args: + documentation (str): The documentation to add. + + Returns: + str: The ID of the training data that was added. + """ pass @abstractmethod def get_training_data(self, **kwargs) -> pd.DataFrame: + """ + Example: + ```python + vn.get_training_data() + ``` + + This method is used to get all the training data from the retrieval layer. + + Returns: + pd.DataFrame: The training data. + """ pass @abstractmethod def remove_training_data(id: str, **kwargs) -> bool: + """ + Example: + ```python + vn.remove_training_data(id="123-ddl") + ``` + + This method is used to remove training data from the retrieval layer. + + Args: + id (str): The ID of the training data to remove. + + Returns: + bool: True if the training data was removed, False otherwise. + """ pass # ----------------- Use Any Language Model API ----------------- # @@ -205,6 +360,29 @@ def get_sql_prompt( doc_list: list, **kwargs, ): + """ + Example: + ```python + vn.get_sql_prompt( + question="What are the top 10 customers by sales?", + question_sql_list=[{"question": "What are the top 10 customers by sales?", "sql": "SELECT * FROM customers ORDER BY sales DESC LIMIT 10"}], + ddl_list=["CREATE TABLE customers (id INT, name TEXT, sales DECIMAL)"], + doc_list=["The customers table contains information about customers and their sales."], + ) + + ``` + + This method is used to generate a prompt for the LLM to generate SQL. + + Args: + question (str): The question to generate SQL for. + question_sql_list (list): A list of questions and their corresponding SQL statements. + ddl_list (list): A list of DDL statements. + doc_list (list): A list of documentation. + + Returns: + any: The prompt for the LLM to generate SQL. + """ initial_prompt = "The user provides a question and you provide SQL. You will only respond with SQL code and not with any explanations.\n\nRespond with only SQL code. Do not answer with any explanations -- just the code.\n" initial_prompt = self.add_ddl_to_prompt( @@ -262,6 +440,25 @@ def get_followup_questions_prompt( @abstractmethod def submit_prompt(self, prompt, **kwargs) -> str: + """ + Example: + ```python + vn.submit_prompt( + [ + vn.system_message("The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."), + vn.user_message("What are the top 10 customers by sales?"), + ] + ) + ``` + + This method is used to submit a prompt to the LLM. + + Args: + prompt (any): The prompt to submit to the LLM. + + Returns: + str: The response from the LLM. + """ pass def generate_question(self, sql: str, **kwargs) -> str: @@ -407,7 +604,7 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame: def connect_to_sqlite(self, url: str): """ - Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] + Connect to a SQLite database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] Args: url (str): The URL of the database to connect to. @@ -447,7 +644,7 @@ def connect_to_postgres( port: int = None, ): """ - Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] + Connect to postgres using the psycopg2 connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** ```python vn.connect_to_postgres( @@ -540,7 +737,7 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]: def connect_to_bigquery(self, cred_file_path: str = None, project_id: str = None): """ - Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] + Connect to gcs using the bigquery connector. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] **Example:** ```python vn.connect_to_bigquery( @@ -629,7 +826,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]: def connect_to_duckdb(self, url: str, init_sql: str = None): """ - Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.run_sql] + Connect to a DuckDB database. This is just a helper function to set [`vn.run_sql`][vanna.base.base.VannaBase.run_sql] Args: url (str): The URL of the database to connect to. @@ -673,9 +870,23 @@ def run_sql_duckdb(sql: str): self.run_sql = run_sql_duckdb self.run_sql_is_set = True - def run_sql(sql: str, **kwargs) -> pd.DataFrame: - raise NotImplementedError( - "You need to connect to a database first by running vn.connect_to_snowflake(), vn.connect_to_sqlite(), vn.connect_to_postgres(), vn.connect_to_bigquery(), or vn.connect_to_duckdb()" + def run_sql(self, sql: str, **kwargs) -> pd.DataFrame: + """ + Example: + ```python + vn.run_sql("SELECT * FROM my_table") + ``` + + Run a SQL query on the connected database. + + Args: + sql (str): The SQL query to run. + + Returns: + pd.DataFrame: The results of the SQL query. + """ + raise Exception( + "You need to connect to a database first by running vn.connect_to_snowflake(), vn.connect_to_postgres(), similar function, or manually set vn.run_sql" ) def ask( @@ -801,10 +1012,10 @@ def train( Train Vanna.AI on a question and its corresponding SQL query. If you call it with no arguments, it will check if you connected to a database and it will attempt to train on the metadata of that database. - If you call it with the sql argument, it's equivalent to [`add_sql()`][vanna.add_sql]. - If you call it with the ddl argument, it's equivalent to [`add_ddl()`][vanna.add_ddl]. - If you call it with the documentation argument, it's equivalent to [`add_documentation()`][vanna.add_documentation]. - Additionally, you can pass a [`TrainingPlan`][vanna.TrainingPlan] object. Get a training plan with [`vn.get_training_plan_experimental()`][vanna.get_training_plan_experimental]. + If you call it with the sql argument, it's equivalent to [`vn.add_question_sql()`][vanna.base.base.VannaBase.add_question_sql]. + If you call it with the ddl argument, it's equivalent to [`vn.add_ddl()`][vanna.base.base.VannaBase.add_ddl]. + If you call it with the documentation argument, it's equivalent to [`vn.add_documentation()`][vanna.base.base.VannaBase.add_documentation]. + Additionally, you can pass a [`TrainingPlan`][vanna.types.TrainingPlan] object. Get a training plan with [`vn.get_training_plan_generic()`][vanna.base.base.VannaBase.get_training_plan_generic]. Args: question (str): The question to train on. @@ -861,6 +1072,17 @@ def _get_information_schema_tables(self, database: str) -> pd.DataFrame: return df_tables def get_training_plan_generic(self, df) -> TrainingPlan: + """ + This method is used to generate a training plan from an information schema dataframe. + + Basically what it does is breaks up INFORMATION_SCHEMA.COLUMNS into groups of table/column descriptions that can be used to pass to the LLM. + + Args: + df (pd.DataFrame): The dataframe to generate the training plan from. + + Returns: + TrainingPlan: The training plan. + """ # For each of the following, we look at the df columns to see if there's a match: database_column = df.columns[ df.columns.str.lower().str.contains("database") diff --git a/src/vanna/flask.py b/src/vanna/flask.py index 5177ceaf..dcf8e713 100644 --- a/src/vanna/flask.py +++ b/src/vanna/flask.py @@ -247,6 +247,14 @@ def generate_plotly_figure(id: str, df, question, sql): def get_training_data(): df = vn.get_training_data() + if df is None or len(df) == 0: + return jsonify( + { + "type": "error", + "error": "No training data found. Please add some training data first.", + } + ) + return jsonify( { "type": "df", diff --git a/src/vanna/ollama/__init__.py b/src/vanna/ollama/__init__.py index 0f4f48e2..d0aee460 100644 --- a/src/vanna/ollama/__init__.py +++ b/src/vanna/ollama/__init__.py @@ -1,6 +1,7 @@ from ..base import VannaBase import requests import json +import re class Ollama(VannaBase): def __init__(self, config=None): @@ -23,6 +24,28 @@ def user_message(self, message: str) -> any: def assistant_message(self, message: str) -> any: return {"role": "assistant", "content": message} + def extract_sql_query(self, text): + """ + Extracts the first SQL statement after the word 'select', ignoring case, + matches until the first semicolon, three backticks, or the end of the string, + and removes three backticks if they exist in the extracted string. + + Args: + - text (str): The string to search within for an SQL statement. + + Returns: + - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found. + """ + # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string + pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL) + + match = pattern.search(text) + if match: + # Remove three backticks from the matched string if they exist + return match.group(0).replace('```', '') + else: + return text + def generate_sql(self, question: str, **kwargs) -> str: # Use the super generate_sql sql = super().generate_sql(question, **kwargs) @@ -30,7 +53,9 @@ def generate_sql(self, question: str, **kwargs) -> str: # Replace "\_" with "_" sql = sql.replace("\\_", "_") - return sql + sql = sql.replace("\\", "") + + return self.extract_sql_query(sql) def submit_prompt(self, prompt, **kwargs) -> str: url = f"{self.host}/api/chat" diff --git a/src/vanna/remote.py b/src/vanna/remote.py index d9164534..9198b0b6 100644 --- a/src/vanna/remote.py +++ b/src/vanna/remote.py @@ -93,6 +93,15 @@ def _rpc_call(self, method, params): def _dataclass_to_dict(self, obj): return dataclasses.asdict(obj) + def system_message(self, message: str) -> any: + pass + + def user_message(self, message: str) -> any: + pass + + def assistant_message(self, message: str) -> any: + pass + def get_training_data(self, **kwargs) -> pd.DataFrame: """ Get the training data for the current model diff --git a/tests/fixtures/questions.json b/tests/fixtures/questions.json deleted file mode 100644 index 6b7cb0c1..00000000 --- a/tests/fixtures/questions.json +++ /dev/null @@ -1,10 +0,0 @@ -[ - { - "question":"what are the year on year total visits on Tesla and ford from 2018 to 2023 ? Convert varchar to date using to date fucntion , Plot a line chart ", - "answer":"SELECT company_name,\r\n extract(year\r\nFROM to_date(date, 'YYYY-MM-DD')) as year, sum(total_visits) as total_visits\r\nFROM s__p_500_by_domain_and_aggregated_by_tickers_sample.datafeeds.sp_500\r\nWHERE (company_name ilike '%Tesla%'\r\n or company_name = 'Ford')\r\n and to_date(date, 'YYYY-MM-DD') between '2018-01-01'\r\n and '2023-12-31'\r\nGROUP BY company_name, extract(year\r\nFROM to_date(date, 'YYYY-MM-DD'))\r\nORDER BY company_name, year;" - }, - { - "question":"Which 10 domains received the highest amount of traffic on Black Friday in 2021 vs 2020", - "answer":"SELECT domain,\n sum(case when date = '2021-11-26' then total_visits\n else 0 end) as visits_2021,\n sum(case when date = '2020-11-27' then total_visits\n else 0 end) as visits_2020\nFROM s__p_500_by_domain_and_aggregated_by_tickers_sample.datafeeds.sp_500\nWHERE date in ('2021-11-26', '2020-11-27')\nGROUP BY domain\nORDER BY (visits_2021 - visits_2020) desc limit 10" - } -] diff --git a/tests/fixtures/sql/testSqlCreate.sql b/tests/fixtures/sql/testSqlCreate.sql deleted file mode 100644 index b9902a10..00000000 --- a/tests/fixtures/sql/testSqlCreate.sql +++ /dev/null @@ -1 +0,0 @@ -CREATE TABLE employees (id INT, name VARCHAR(255), salary INT); diff --git a/tests/fixtures/sql/testSqlSelect.sql b/tests/fixtures/sql/testSqlSelect.sql deleted file mode 100644 index c2bd8e98..00000000 --- a/tests/fixtures/sql/testSqlSelect.sql +++ /dev/null @@ -1 +0,0 @@ -SELECT * FROM students WHERE name = 'Jane Doe'; diff --git a/tests/test_vanna.py b/tests/test_vanna.py index 9c794ca4..c70aae4f 100644 --- a/tests/test_vanna.py +++ b/tests/test_vanna.py @@ -1,505 +1,54 @@ -import vanna as vn -import requests -import sys -import io -import pandas as pd -import contextlib -import stat -import os -import pytest -from vanna.exceptions import ValidationError, ImproperlyConfigured - -endpoint_base = os.environ.get('VANNA_ENDPOINT', 'https://debug.vanna.ai') - -vn._endpoint = endpoint_base + '/rpc' -vn._unauthenticated_endpoint = endpoint_base + '/unauthenticated_rpc' - -## Helper functions -def switch_to_user(user, monkeypatch): - monkeypatch.setattr(sys, 'stdin', io.StringIO('DEBUG\n')) - - api_key = vn.get_api_key(email=f'{user}@example.com') - vn.set_api_key(api_key) - -## Tests - -def test_debug_env(): - # Get endpoint_base + '/reset' - r = requests.get(endpoint_base + '/reset') - assert r.status_code == 200 - assert r.text == 'Database reset' - -def test_create_user1(monkeypatch): - monkeypatch.setattr(sys, 'stdin', io.StringIO('DEBUG\n')) - - api_key = vn.get_api_key(email='user1@example.com') - vn.set_api_key(api_key) - - models = vn.get_models() - - assert models == ['demo-tpc-h'] - - -@pytest.mark.parametrize("model_name", ["Test @Org_"]) -def test_create_model(model_name): - rv = vn.create_model(model=model_name, db_type='Snowflake') - assert rv == True - - models = vn.get_models() - assert 'test-org' in models - - -def test_is_user1_in_model(): - rv = vn.get_models() - assert rv == ['demo-tpc-h', 'test-org'] - - -def test_is_user2_in_model(monkeypatch): - switch_to_user('user2', monkeypatch) - - models = vn.get_models() - - assert models == ['demo-tpc-h'] - -def test_switch_back_to_user1(monkeypatch): - switch_to_user('user1', monkeypatch) - - models = vn.get_models() - assert models == ['demo-tpc-h', 'test-org'] - -def test_set_model_my_model(): - with pytest.raises(ValidationError): - vn.set_model('my-model') - -def test_set_model(): - vn.set_model('test-org') - assert vn.__org == 'test-org' # type: ignore - -def test_add_user_to_model(monkeypatch): - rv = vn.add_user_to_model(model='test-org', email="user2@example.com", is_admin=False) - assert rv == True - - switch_to_user('user2', monkeypatch) - models = vn.get_models() - assert models == ['demo-tpc-h', 'test-org'] - -def test_update_model_visibility(monkeypatch): - rv = vn.update_model_visibility(public=True) - # user2 is not an admin, so this should fail - assert rv == False - - switch_to_user('user1', monkeypatch) - rv = vn.update_model_visibility(public=True) - - switch_to_user('user3', monkeypatch) - models = vn.get_models() - assert models == ['demo-tpc-h', 'test-org'] - - switch_to_user('user1', monkeypatch) - - rv = vn.update_model_visibility(public=False) - assert rv == True - - switch_to_user('user3', monkeypatch) - - models = vn.get_models() - assert models == ['demo-tpc-h'] - -def test_generate_explanation(monkeypatch): - switch_to_user('user1', monkeypatch) - rv = vn.generate_explanation(sql="SELECT * FROM students WHERE name = 'John Doe'") - assert rv == 'AI Response' - -def test_generate_question(): - rv = vn.generate_question(sql="SELECT * FROM students WHERE name = 'John Doe'") - assert rv == 'AI Response' - -def test_generate_sql(): - rv = vn.generate_sql(question="Who are the top 10 customers by Sales?") - assert rv == 'No SELECT statement could be found in the SQL code' - -def test_generate_plotly(): - data = { - 'Name': ['John', 'Emma', 'Tom', 'Emily', 'Alex'], - 'Age': [25, 28, 22, 31, 24], - 'Country': ['USA', 'Canada', 'UK', 'Australia', 'USA'], - 'Salary': [50000, 60000, 45000, 70000, 55000] - } - - # Create a dataframe from the dictionary - df = pd.DataFrame(data) - - rv = vn.generate_plotly_code(question="Who are the top 10 customers by Sales?", sql="SELECT * FROM students WHERE name = 'John Doe'", df=df) - assert rv == 'AI Response' - -def test_generate_questions(): - rv = vn.generate_questions() - assert rv == ['AI Response'] - -def test_generate_followup_questions(): - data = { - 'Name': ['John', 'Emma', 'Tom', 'Emily', 'Alex'], - 'Age': [25, 28, 22, 31, 24], - 'Country': ['USA', 'Canada', 'UK', 'Australia', 'USA'], - 'Salary': [50000, 60000, 45000, 70000, 55000] - } - - # Create a dataframe from the dictionary - df = pd.DataFrame(data) - - questions = vn.generate_followup_questions(question="Who are the top 10 customers by Sales?", df=df) - - assert questions == ['AI Response'] - -def test_add_sql(): - rv = vn.add_sql(question="What's the data about student John Doe?", sql="SELECT * FROM students WHERE name = 'John Doe'") - assert rv == True +from vanna.openai.openai_chat import OpenAI_Chat +from vanna.vannadb.vannadb_vector import VannaDB_VectorStore +from vanna.mistral.mistral import Mistral +from vanna.remote import VannaDefault - rv = vn.add_sql(question="What's the data about student Jane Doe?", sql="SELECT * FROM students WHERE name = 'Jane Doe'") - assert rv == True - -def test_generate_sql_caching(): - rv = vn.generate_sql(question="What's the data about student John Doe?") - - assert rv == 'SELECT * FROM students WHERE name = \'John Doe\'' - -def test_remove_sql(): - rv = vn.remove_sql(question="What's the data about student John Doe?") - assert rv == True - -def test_flag_sql(): - rv = vn.flag_sql_for_review(question="What's the data about student Jane Doe?") - assert rv == True - -def test_get_all_questions(): - rv = vn.get_all_questions() - assert rv.shape == (3, 5) - - vn.set_model('demo-tpc-h') - rv = vn.get_all_questions() - assert rv.shape == (0, 0) - -# def test_get_accuracy_stats(): -# rv = vn.get_accuracy_stats() -# assert rv == AccuracyStats(num_questions=2, data={'No SQL Generated': 2, 'SQL Unable to Run': 0, 'Assumed Correct': 0, 'Flagged for Review': 0, 'Reviewed and Approved': 0, 'Reviewed and Rejected': 0, 'Reviewed and Updated': 0}) - -def test_add_documentation_fail(): - rv = vn.add_documentation(documentation="This is the documentation") - assert rv == False - -def test_add_ddl_pass_fail(): - rv = vn.add_ddl(ddl="This is the ddl") - assert rv == False - -def test_add_sql_pass_fail(): - rv = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students") - assert rv == False - -def test_add_documentation_pass(monkeypatch): - switch_to_user('user1', monkeypatch) - vn.set_model('test-org') - rv = vn.add_documentation(documentation="This is the documentation") - assert rv == True - -def test_add_ddl_pass(): - rv = vn.add_ddl(ddl="This is the ddl") - assert rv == True - -def test_add_sql_pass(): - rv = vn.add_sql(question="How many students are there?", sql="SELECT * FROM students") - assert rv == True - -num_training_data = 4 - -def test_get_training_data(): - rv = vn.get_training_data() - assert rv.shape == (num_training_data, 4) - -def test_remove_training_data(): - training_data = vn.get_training_data() - - for index, row in training_data.iterrows(): - rv = vn.remove_training_data(row['id']) - assert rv == True - - assert vn.get_training_data().shape[0] == num_training_data-1-index - -def test_create_model_and_add_user(): - created = vn.create_model('test-org2', 'Snowflake') - assert created == True - - added = vn.add_user_to_model(model='test-org2', email="user5@example.com", is_admin=False) - assert added == True - -def test_ask_no_output(): - vn.run_sql = lambda sql: pd.DataFrame({'Name': ['John', 'Emma', 'Tom', 'Emily', 'Alex']}) - vn.generate_sql = lambda question: 'SELECT * FROM students' - vn.ask(question="How many students are there?") - -def test_ask_with_output(): - sql, df, fig, followup_questions = vn.ask(question="How many students are there?", print_results=False) - - assert sql == 'SELECT * FROM students' - - assert df.to_csv() == ',Name\n0,John\n1,Emma\n2,Tom\n3,Emily\n4,Alex\n' - -def test_generate_meta(): - meta = vn.generate_meta("What tables are available?") - - assert meta == 'AI Response' - -def test_double_train(): - vn.set_model('test-org') - - training_data = vn.get_training_data() - assert training_data.shape == (0, 0) - - trained = vn.train(question="What's the data about student John Doe?", sql="SELECT * FROM students WHERE name = 'John Doe'") - assert trained == True - - training_data = vn.get_training_data() - assert training_data.shape == (1, 4) - - vn.train(question="What's the data about student John Doe?", sql="SELECT * FROM students WHERE name = 'John Doe'") - - training_data = vn.get_training_data() - assert training_data.shape == (1, 4) - -def test_get_related_training_data(): - data = vn.get_related_training_data(question="What's the data about student John Doe?") - assert data.questions[0]['question'] == 'What is the total sales for each product?' - assert data.questions[0]['sql'] == 'SELECT * FROM ...' - assert data.ddl == ['DDL here'] - assert data.documentation == ['Documentation here'] - -@pytest.mark.parametrize("params", [ - dict( - question=None, - sql="SELECT * FROM students WHERE name = 'Jane Doe'", - documentation=False, - ddl=None, - sql_file=None, - json_file=None, - ), - dict( - question=None, - sql="SELECT * FROM students WHERE name = 'Jane Doe'", - documentation=True, - ddl=None, - sql_file=None, - json_file=None, - ), - dict( - question=None, - sql=None, - documentation=False, - ddl="This is the ddl", - sql_file=None, - json_file=None, - ), - dict( - question=None, - sql=None, - documentation=False, - ddl=None, - sql_file="tests/fixtures/sql/testSqlSelect.sql", - json_file=None, - ), - dict( - question=None, - sql=None, - documentation=False, - ddl=None, - sql_file=None, - json_file="tests/fixtures/questions.json" - ), - dict( - question=None, - sql=None, - documentation=False, - ddl=None, - sql_file="tests/fixtures/sql/testSqlCreate.sql", - json_file=None, - ), -]) -def test_train_success(monkeypatch, params): - vn.set_model('test-org') - assert vn.train(**params) - - -@pytest.mark.parametrize("params, expected_exc_class", [ - ( - dict( - question="What's the data about student John Doe?", - sql=None, - documentation=False, - ddl=None, - sql_file=None, - json_file=None, - ), - ValidationError - ), - ( - dict( - question=None, - sql=None, - documentation=False, - ddl=None, - sql_file="wrong/path/or/file.sql", - json_file=None, - ), - ImproperlyConfigured - ), - ( - dict( - question=None, - sql=None, - documentation=False, - ddl=None, - sql_file=None, - json_file="wrong/path/or/file.json", - ), - ImproperlyConfigured - ) -]) -def test_train_validations(monkeypatch, params, expected_exc_class): - vn.set_model('test-org') - - with pytest.raises((ValidationError, ImproperlyConfigured)) as exc: - vn.train(**params) - assert isinstance(exc, expected_exc_class) - - -@pytest.mark.parametrize('model_name', [1234, ['test_org']]) -def test_set_model_validation(model_name): - # test invalid model name - with pytest.raises(ValidationError) as exc: - vn.set_model(model_name) - assert "Please provide model name in string format" in exc.args[0] +import os -def mock_connector(host, dbname, user, password, port): +try: + print("Trying to load .env") + from dotenv import load_dotenv + load_dotenv() +except Exception as e: + print(f"Failed to load .env {e}") pass +MY_VANNA_MODEL = 'chinook' +MY_VANNA_API_KEY = os.environ['VANNA_API_KEY'] +OPENAI_API_KEY = os.environ['OPENAI_API_KEY'] +MISTRAL_API_KEY = os.environ['MISTRAL_API_KEY'] -@pytest.mark.parametrize('params, none_param', [ - ( - dict( - host=None, - dbname="test-db", - user="test-user", - password="test-password", - port=5432 - ), - "host" - ), - ( - dict( - host="localhost", - dbname=None, - user="test-user", - password="test-password", - port=5432, - ), - "database" - ), - ( - dict( - host="localhost", - dbname="test-db", - user=None, - password="test-password", - port=5432, - ), - "user" - ), - ( - dict( - host="localhost", - dbname="test-db", - user="test-user", - password=None, - port=5432, - ), - "password", - ), - ( - dict( - host="localhost", - dbname="test-db", - user="test-user", - password="test-password", - port=None, - ), - "port" - ), -]) -def test_connect_to_postgres_validations(monkeypatch, params, none_param): - monkeypatch.setattr("psycopg2.connect", mock_connector) - with pytest.raises(ImproperlyConfigured) as exc: - vn.connect_to_postgres(**params) - assert f"Please set your postgres {none_param}" in exc.args[0] - - -class Client: - def query(self, query): - - pass - +class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat): + def __init__(self, config=None): + VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) + OpenAI_Chat.__init__(self, config=config) -@pytest.mark.parametrize("params", [ - dict(project_id=None), -]) -def test_connect_to_bigquery_validations(monkeypatch, params): - monkeypatch.setattr("google.cloud.bigquery.Client", Client) - with pytest.raises(ImproperlyConfigured) as exc: - vn.connect_to_bigquery(**params) - assert "Please set your Google Cloud Project ID." in exc.args[0] +vn_openai = VannaOpenAI(config={'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'}) +vn_openai.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') +def test_vn_openai(): + sql = vn_openai.generate_sql("What are the top 4 customers by sales?") + df = vn_openai.run_sql(sql) + assert len(df) == 4 -@pytest.mark.parametrize("params, expected_err", [ - ( - dict( - project_id="test-project", - cred_file_path="wrong/file/path.json" - ), - "No such configuration file: wrong/file/path.json" - ), - ( - dict( - project_id="test-project", - cred_file_path="tests" - ), - "Config should be a file: tests" - ) -]) -def test_connect_to_bigquery_creds_path_validations(monkeypatch, params, expected_err): - monkeypatch.setattr("google.cloud.bigquery.Client", Client) - with pytest.raises(ImproperlyConfigured) as exc: - vn.connect_to_bigquery(**params) - assert expected_err in exc.args[0] +class VannaMistral(VannaDB_VectorStore, Mistral): + def __init__(self, config=None): + VannaDB_VectorStore.__init__(self, vanna_model=MY_VANNA_MODEL, vanna_api_key=MY_VANNA_API_KEY, config=config) + Mistral.__init__(self, config={'api_key': MISTRAL_API_KEY, 'model': 'mistral-tiny'}) +vn_mistral = VannaMistral() +vn_mistral.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') -@pytest.mark.parametrize("params", [ - dict( - project_id="test-project", - cred_file_path="tests/test-creds.json" - ), -]) -def test_connect_to_bigquery_creds_file_permissions(monkeypatch, params): - monkeypatch.setattr("google.cloud.bigquery.Client", Client) - with create_file(params["cred_file_path"]) as creds_path: - with pytest.raises(ImproperlyConfigured) as exc: - vn.connect_to_bigquery(**params) - assert f"Cannot read the config file. Please grant read privileges: {creds_path}" in exc.args[0] +def test_vn_mistral(): + sql = vn_mistral.generate_sql("What are the top 5 customers by sales?") + df = vn_mistral.run_sql(sql) + assert len(df) == 5 +vn_default = VannaDefault(model=MY_VANNA_MODEL, api_key=MY_VANNA_API_KEY) +vn_default.connect_to_sqlite('https://vanna.ai/Chinook.sqlite') -@contextlib.contextmanager -def create_file(file_path): - with open(file_path, "w") as f: - pass - os.chmod(file_path, stat.S_IWUSR) - try: - yield file_path - finally: - os.remove(file_path) +def test_vn_default(): + sql = vn_default.generate_sql("What are the top 6 customers by sales?") + df = vn_default.run_sql(sql) + assert len(df) == 6 \ No newline at end of file diff --git a/tox.ini b/tox.ini index d261900c..bd024b8b 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,7 @@ [tox] envlist = py310, + mac, flake8, [py] @@ -15,6 +16,15 @@ extras = all basepython = python3.10 commands = pytest -v --cov=tests/ --cov-report=term --cov-report=html +[testenv:mac] +deps= + {[py]deps} + python-dotenv +extras = all +basepython = python +commands = + pytest -v --cov=tests/ --cov-report=term --cov-report=html + [testenv:flake8] exclude = .tox/* deps = flake8 From 6a5f6a40ad30fd2ab1d0714f1d7fd604b8f0ade7 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Sat, 10 Feb 2024 10:27:48 -0500 Subject: [PATCH 4/8] test env --- .github/workflows/tests.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 7244091a..4dd11069 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -1,4 +1,4 @@ -name: Integration Test using Debug Server +name: Basic Integration Tests on: pull_request: {} @@ -12,6 +12,10 @@ jobs: steps: - uses: actions/checkout@v3 - name: Set up Python 3.10 + env: + VANNA_API_KEY: ${{ secrets.VANNA_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} uses: actions/setup-python@v3 with: python-version: "3.10" From 92fd82368c3a5a77ba7f00da5cdc35c33a2785a0 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Sat, 10 Feb 2024 11:32:32 -0500 Subject: [PATCH 5/8] test env --- .github/workflows/tests.yml | 8 ++++---- tox.ini | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4dd11069..173bba35 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,10 +12,6 @@ jobs: steps: - uses: actions/checkout@v3 - name: Set up Python 3.10 - env: - VANNA_API_KEY: ${{ secrets.VANNA_API_KEY }} - OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} - MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} uses: actions/setup-python@v3 with: python-version: "3.10" @@ -24,4 +20,8 @@ jobs: python -m pip install --upgrade pip pip install tox - name: Run tests + env: + VANNA_API_KEY: ${{ secrets.VANNA_API_KEY }} + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }} run: tox -e py310 diff --git a/tox.ini b/tox.ini index bd024b8b..a406f326 100644 --- a/tox.ini +++ b/tox.ini @@ -13,6 +13,7 @@ deps= deps= {[py]deps} extras = all +passenv = * basepython = python3.10 commands = pytest -v --cov=tests/ --cov-report=term --cov-report=html From 8de26dc335759cb38fb7549055411ae88f1c7ce0 Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Mon, 12 Feb 2024 16:29:35 -0500 Subject: [PATCH 6/8] v4 actions --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 173bba35..49a40f3f 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -10,9 +10,9 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v3 + uses: actions/setup-python@v4 with: python-version: "3.10" - name: Install pip From 4142cb31af1bb6d8bad10d348f4efc1a2eaec9dc Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Mon, 12 Feb 2024 16:35:06 -0500 Subject: [PATCH 7/8] setup-python v5 --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 49a40f3f..c62bd72b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -12,7 +12,7 @@ jobs: steps: - uses: actions/checkout@v4 - name: Set up Python 3.10 - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: "3.10" - name: Install pip From 9aa3a191d5142f1a822f65a855eded86560a746f Mon Sep 17 00:00:00 2001 From: Zain Hoda <7146154+zainhoda@users.noreply.github.com> Date: Mon, 12 Feb 2024 19:53:27 -0500 Subject: [PATCH 8/8] update readme --- .gitignore | 3 ++- README.md | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index d7285afd..05cf04ea 100644 --- a/.gitignore +++ b/.gitignore @@ -13,4 +13,5 @@ docs/*.html notebooks/chroma.sqlite3 dist .env -*.sqlite \ No newline at end of file +*.sqlite +htmlcov \ No newline at end of file diff --git a/README.md b/README.md index d574f4b5..7cd033d2 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ See the [base class](src/vanna/base/base.py) for more details on how this works ## User Interfaces These are some of the user interfaces that we've built using Vanna. You can use these as-is or as a starting point for your own custom interface. -- [Jupyter Notebook](https://github.com/vanna-ai/vanna/blob/main/notebooks/getting-started.ipynb) +- [Jupyter Notebook](https://vanna.ai/docs/postgres-openai-vanna-vannadb/) - [vanna-ai/vanna-streamlit](https://github.com/vanna-ai/vanna-streamlit) - [vanna-ai/vanna-flask](https://github.com/vanna-ai/vanna-flask) - [vanna-ai/vanna-slack](https://github.com/vanna-ai/vanna-slack) @@ -39,7 +39,7 @@ These are some of the user interfaces that we've built using Vanna. You can use ## Getting started See the [documentation](https://vanna.ai/docs/) for specifics on your desired database, LLM, etc. -If you want to get a feel for how it works after training, you can try this [Colab notebook](https://colab.research.google.com/github/vanna-ai/vanna/blob/main/notebooks/getting-started.ipynb). +If you want to get a feel for how it works after training, you can try this [Colab notebook](https://vanna.ai/docs/app/). ### Install