Generate summaries and followup questions #271

Merged · 2 commits · Mar 2, 2024
6 changes: 0 additions & 6 deletions .pre-commit-config.yaml
@@ -17,9 +17,3 @@ repos:
hooks:
- id: isort
args: [ "--profile", "black", "--filter-files" ]

- repo: https://github.com/odwyersoftware/brunette
rev: 238bead5ec5c58935d6bb12c70f435f70b2bf785
hooks:
- id: brunette
args: [ '--config=setup.cfg' ]
118 changes: 91 additions & 27 deletions src/vanna/base/base.py
@@ -72,6 +72,7 @@ class VannaBase(ABC):
def __init__(self, config=None):
self.config = config
self.run_sql_is_set = False
self.static_documentation = ""

def log(self, message: str):
print(message)
@@ -140,18 +141,35 @@ def is_sql_valid(self, sql: str) -> bool:
else:
return False

def generate_followup_questions(self, question: str, **kwargs) -> str:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
prompt = self.get_followup_questions_prompt(
question=question,
question_sql_list=question_sql_list,
ddl_list=ddl_list,
doc_list=doc_list,
**kwargs,
)
llm_response = self.submit_prompt(prompt, **kwargs)
def generate_followup_questions(
self, question: str, sql: str, df: pd.DataFrame, **kwargs
) -> list:
"""
**Example:**
```python
vn.generate_followup_questions("What are the top 10 customers by sales?", sql, df)
```

Generate a list of followup questions that you can ask Vanna.AI.

Args:
question (str): The question that was asked.
sql (str): The SQL query that was generated to answer the question.
df (pd.DataFrame): The results of the SQL query.

Returns:
list: A list of followup questions that you can ask Vanna.AI.
"""

message_log = [
self.system_message(
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe SQL query for this question was: {sql}\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
),
self.user_message(
"Generate a list of followup questions that the user might ask about this data. Respond with a list of questions, one per line. Do not answer with any explanations -- just the questions. Remember that there should be an unambiguous SQL query that can be generated from the question. Prefer questions that are answerable outside of the context of this conversation. Prefer questions that are slight modifications of the SQL query that was generated that allow digging deeper into the data. Each question will be turned into a button that the user can click to generate a new SQL query so don't use 'example' type questions. Each question must have a one-to-one correspondence with an instantiated SQL query."
),
]

llm_response = self.submit_prompt(message_log, **kwargs)

numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
return numbers_removed.split("\n")
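The new signature takes the generated SQL and the result DataFrame alongside the question. A minimal usage sketch, assuming a trained `vn` instance with a working database connection (the question text below is illustrative, not from this PR):

```python
# Hypothetical end-to-end flow: generate SQL, run it, then ask for follow-up questions.
question = "What are the top 10 customers by sales?"
sql = vn.generate_sql(question)   # LLM-generated SQL for the question
df = vn.run_sql(sql)              # pandas DataFrame with the query results

followups = vn.generate_followup_questions(question, sql, df)
for q in followups:
    print(q)  # each line is a candidate follow-up question
```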
@@ -169,6 +187,36 @@ def generate_questions(self, **kwargs) -> List[str]:

return [q["question"] for q in question_sql]

def generate_summary(self, question: str, df: pd.DataFrame, **kwargs) -> str:
"""
**Example:**
```python
vn.generate_summary("What are the top 10 customers by sales?", df)
```

Generate a summary of the results of a SQL query.

Args:
question (str): The question that was asked.
df (pd.DataFrame): The results of the SQL query.

Returns:
str: The summary of the results of the SQL query.
"""

message_log = [
self.system_message(
f"You are a helpful data assistant. The user asked the question: '{question}'\n\nThe following is a pandas DataFrame with the results of the query: \n{df.to_markdown()}\n\n"
),
self.user_message(
"Briefly summarize the data based on the question that was asked. Do not respond with any additional explanation beyond the summary."
),
]

summary = self.submit_prompt(message_log, **kwargs)

return summary
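A short sketch of calling the new summary helper directly; the DataFrame here is hand-built stand-in data, not output from this PR:

```python
import pandas as pd

# Hypothetical results DataFrame standing in for a real query result.
df = pd.DataFrame({"customer": ["Acme", "Globex"], "sales": [1200, 950]})

summary = vn.generate_summary("What are the top 10 customers by sales?", df)
print(summary)  # a brief natural-language description of the data
```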

# ----------------- Use Any Embeddings API ----------------- #
@abstractmethod
def generate_embedding(self, data: str, **kwargs) -> List[float]:
@@ -184,7 +232,7 @@ def get_similar_question_sql(self, question: str, **kwargs) -> list:
question (str): The question to get similar questions and their corresponding SQL statements for.

Returns:
list: A list of similar questions and their corresponding SQL statements.
"""
pass

@@ -224,15 +272,15 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
sql (str): The SQL query to add.

Returns:
str: The ID of the training data that was added.
"""
pass

@abstractmethod
def add_ddl(self, ddl: str, **kwargs) -> str:
"""
This method is used to add a DDL statement to the training data.

Args:
ddl (str): The DDL statement to add.

@@ -265,7 +313,7 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:
This method is used to get all the training data from the retrieval layer.

Returns:
pd.DataFrame: The training data.
"""
pass

@@ -321,7 +369,10 @@ def add_ddl_to_prompt(
return initial_prompt

def add_documentation_to_prompt(
self, initial_prompt: str, documentation_list: list[str], max_tokens: int = 14000
self,
initial_prompt: str,
documentation_list: list[str],
max_tokens: int = 14000,
) -> str:
if len(documentation_list) > 0:
initial_prompt += f"\nYou may use the following documentation as a reference for what tables might be available. Use responses to past questions also to guide you:\n\n"
@@ -389,6 +440,9 @@ def get_sql_prompt(
initial_prompt, ddl_list, max_tokens=14000
)

if self.static_documentation != "":
doc_list.append(self.static_documentation)

initial_prompt = self.add_documentation_to_prompt(
initial_prompt, doc_list, max_tokens=14000
)
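To see how the new `static_documentation` field reaches the prompt (a sketch, not part of the diff): each `connect_to_*` helper stores a one-line dialect hint, and `get_sql_prompt` appends it to the retrieved documentation before the prompt is assembled.

```python
# Hypothetical illustration of the static_documentation flow; DuckDB is used because
# the diff shows the ":memory:" shortcut explicitly.
vn.connect_to_duckdb(":memory:")  # sets vn.static_documentation = "This is a DuckDB database"

# When SQL is generated, the dialect hint is appended to doc_list inside get_sql_prompt,
# so the LLM is told which database flavor to target.
sql = vn.generate_sql("How many tables are in this database?")
```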
@@ -599,6 +653,7 @@ def run_sql_snowflake(sql: str) -> pd.DataFrame:

return df

self.static_documentation = "This is a Snowflake database"
self.run_sql = run_sql_snowflake
self.run_sql_is_set = True

@@ -632,6 +687,7 @@ def connect_to_sqlite(self, url: str):
def run_sql_sqlite(sql: str):
return pd.read_sql_query(sql, conn)

self.static_documentation = "This is a SQLite database"
self.run_sql = run_sql_sqlite
self.run_sql_is_set = True

@@ -731,11 +787,12 @@ def run_sql_postgres(sql: str) -> Union[pd.DataFrame, None]:
except psycopg2.Error as e:
conn.rollback()
raise ValidationError(e)

except Exception as e:
conn.rollback()
raise e

self.static_documentation = "This is a Postgres database"
self.run_sql_is_set = True
self.run_sql = run_sql_postgres

@@ -825,6 +882,7 @@ def run_sql_bigquery(sql: str) -> Union[pd.DataFrame, None]:
raise errors
return None

self.static_documentation = "This is a BigQuery database"
self.run_sql_is_set = True
self.run_sql = run_sql_bigquery

@@ -847,13 +905,13 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
" run command: \npip install vanna[duckdb]"
)
# URL of the database to download
if url==":memory:" or url=="":
path=":memory:"
if url == ":memory:" or url == "":
path = ":memory:"
else:
# Path to save the downloaded database
print(os.path.exists(url))
if os.path.exists(url):
path=url
path = url
elif url.startswith("md") or url.startswith("motherduck"):
path = url
else:
@@ -873,6 +931,7 @@ def connect_to_duckdb(self, url: str, init_sql: str = None):
def run_sql_duckdb(sql: str):
return conn.query(sql).to_df()

self.static_documentation = "This is a DuckDB database"
self.run_sql = run_sql_duckdb
self.run_sql_is_set = True

@@ -895,27 +954,31 @@ def connect_to_mssql(self, odbc_conn_str: str):
)

try:
from sqlalchemy.engine import URL
import sqlalchemy as sa
from sqlalchemy.engine import URL
except ImportError:
raise DependencyError(
"You need to install required dependencies to execute this method,"
" run command: pip install sqlalchemy"
)

connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": odbc_conn_str})
connection_url = URL.create(
"mssql+pyodbc", query={"odbc_connect": odbc_conn_str}
)

from sqlalchemy import create_engine

engine = create_engine(connection_url)

def run_sql_mssql(sql: str):
# Execute the SQL statement and return the result as a pandas DataFrame
with engine.begin() as conn:
df = pd.read_sql_query(sa.text(sql), conn)
return df

raise Exception("Couldn't run sql")

self.static_documentation = "This is a Microsoft SQL Server database"
self.run_sql = run_sql_mssql
self.run_sql_is_set = True
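A hedged example of wiring up the SQL Server helper; the driver name, server, and credentials are placeholders, not values from this PR:

```python
# Hypothetical ODBC connection string; adjust driver, server, and credentials to your setup.
odbc_conn_str = (
    "DRIVER={ODBC Driver 17 for SQL Server};"
    "SERVER=localhost,1433;"
    "DATABASE=mydb;"
    "UID=vanna_user;"
    "PWD=change_me"
)

vn.connect_to_mssql(odbc_conn_str)  # also sets the "Microsoft SQL Server" static_documentation hint
df = vn.run_sql("SELECT TOP 5 name FROM sys.tables")
```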

@@ -943,7 +1006,7 @@ def ask(
question: Union[str, None] = None,
print_results: bool = True,
auto_train: bool = True,
visualize: bool = True, # if False, will not generate plotly code
) -> Union[
Tuple[
Union[str, None],
@@ -1024,7 +1087,9 @@ def ask(
display = __import__(
"IPython.display", fromlist=["display"]
).display
Image = __import__("IPython.display", fromlist=["Image"]).Image
Image = __import__(
"IPython.display", fromlist=["Image"]
).Image
img_bytes = fig.to_image(format="png", scale=2)
display(Image(img_bytes))
except Exception as e:
@@ -1377,4 +1442,3 @@ def get_plotly_figure(
fig.update_layout(template="plotly_dark")

return fig
