Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make documentation argument name consistent #156

Merged
merged 1 commit into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.0.32"
version = "0.0.33"
authors = [
{ name="Zain Hoda", email="zain@vanna.ai" },
]
Expand Down
27 changes: 14 additions & 13 deletions src/vanna/base/base.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import json
import os
import re
import sqlite3
import traceback

from abc import ABC, abstractmethod
from typing import List, Tuple, Union
from urllib.parse import urlparse
Expand All @@ -12,7 +12,6 @@
import plotly.express as px
import plotly.graph_objects as go
import requests
import re

from ..exceptions import DependencyError, ImproperlyConfigured, ValidationError
from ..types import TrainingPlan, TrainingPlanItem
Expand Down Expand Up @@ -50,8 +49,8 @@ def generate_followup_questions(self, question: str, **kwargs) -> str:
**kwargs,
)
llm_response = self.submit_prompt(prompt, **kwargs)
numbers_removed = re.sub(r'^\d+\.\s*', '', llm_response, flags=re.MULTILINE)

numbers_removed = re.sub(r"^\d+\.\s*", "", llm_response, flags=re.MULTILINE)
return numbers_removed.split("\n")

def generate_questions(self, **kwargs) -> list[str]:
Expand All @@ -65,7 +64,7 @@ def generate_questions(self, **kwargs) -> list[str]:
"""
question_sql = self.get_similar_question_sql(question="", **kwargs)

return [q['question'] for q in question_sql]
return [q["question"] for q in question_sql]

# ----------------- Use Any Embeddings API ----------------- #
@abstractmethod
Expand Down Expand Up @@ -94,7 +93,7 @@ def add_ddl(self, ddl: str, **kwargs) -> str:
pass

@abstractmethod
def add_documentation(self, doc: str, **kwargs) -> str:
def add_documentation(self, documentation: str, **kwargs) -> str:
pass

@abstractmethod
Expand All @@ -120,12 +119,12 @@ def get_sql_prompt(

@abstractmethod
def get_followup_questions_prompt(
self,
question: str,
self,
question: str,
question_sql_list: list,
ddl_list: list,
doc_list: list,
**kwargs
doc_list: list,
**kwargs,
):
pass

Expand Down Expand Up @@ -829,9 +828,11 @@ def get_plotly_figure(
fig = ldict.get("fig", None)
except Exception as e:
# Inspect data types
numeric_cols = df.select_dtypes(include=['number']).columns.tolist()
categorical_cols = df.select_dtypes(include=['object', 'category']).columns.tolist()

numeric_cols = df.select_dtypes(include=["number"]).columns.tolist()
categorical_cols = df.select_dtypes(
include=["object", "category"]
).columns.tolist()

# Decision-making for plot type
if len(numeric_cols) >= 2:
# Use the first two numeric columns for a scatter plot
Expand Down
62 changes: 34 additions & 28 deletions src/vanna/chromadb/chromadb_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from abc import abstractmethod

import chromadb
import pandas as pd
from chromadb.config import Settings
from chromadb.utils import embedding_functions
import pandas as pd

from ..base import VannaBase

Expand Down Expand Up @@ -47,7 +47,7 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
"sql": sql,
}
)
id = str(uuid.uuid4())+"-sql"
id = str(uuid.uuid4()) + "-sql"
self.sql_collection.add(
documents=question_sql_json,
embeddings=self.generate_embedding(question_sql_json),
Expand All @@ -57,19 +57,19 @@ def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
return id

def add_ddl(self, ddl: str, **kwargs) -> str:
id = str(uuid.uuid4())+"-ddl"
id = str(uuid.uuid4()) + "-ddl"
self.ddl_collection.add(
documents=ddl,
embeddings=self.generate_embedding(ddl),
ids=id,
)
return id

def add_documentation(self, doc: str, **kwargs) -> str:
id = str(uuid.uuid4())+"-doc"
def add_documentation(self, documentation: str, **kwargs) -> str:
id = str(uuid.uuid4()) + "-doc"
self.documentation_collection.add(
documents=doc,
embeddings=self.generate_embedding(doc),
documents=documentation,
embeddings=self.generate_embedding(documentation),
ids=id,
)
return id
Expand All @@ -81,15 +81,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

if sql_data is not None:
# Extract the documents and ids
documents = [json.loads(doc) for doc in sql_data['documents']]
ids = sql_data['ids']
documents = [json.loads(doc) for doc in sql_data["documents"]]
ids = sql_data["ids"]

# Create a DataFrame
df_sql = pd.DataFrame({
'id': ids,
'question': [doc['question'] for doc in documents],
'content': [doc['sql'] for doc in documents]
})
df_sql = pd.DataFrame(
{
"id": ids,
"question": [doc["question"] for doc in documents],
"content": [doc["sql"] for doc in documents],
}
)

df_sql["training_data_type"] = "sql"

Expand All @@ -99,15 +101,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

if ddl_data is not None:
# Extract the documents and ids
documents = [doc for doc in ddl_data['documents']]
ids = ddl_data['ids']
documents = [doc for doc in ddl_data["documents"]]
ids = ddl_data["ids"]

# Create a DataFrame
df_ddl = pd.DataFrame({
'id': ids,
'question': [None for doc in documents],
'content': [doc for doc in documents]
})
df_ddl = pd.DataFrame(
{
"id": ids,
"question": [None for doc in documents],
"content": [doc for doc in documents],
}
)

df_ddl["training_data_type"] = "ddl"

Expand All @@ -117,15 +121,17 @@ def get_training_data(self, **kwargs) -> pd.DataFrame:

if doc_data is not None:
# Extract the documents and ids
documents = [doc for doc in doc_data['documents']]
ids = doc_data['ids']
documents = [doc for doc in doc_data["documents"]]
ids = doc_data["ids"]

# Create a DataFrame
df_doc = pd.DataFrame({
'id': ids,
'question': [None for doc in documents],
'content': [doc for doc in documents]
})
df_doc = pd.DataFrame(
{
"id": ids,
"question": [None for doc in documents],
"content": [doc for doc in documents],
}
)

df_doc["training_data_type"] = "documentation"

Expand Down
Loading