Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: pinecone vectorstore #426

Merged
merged 5 commits into from
May 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ dist
htmlcov
chroma.sqlite3
*.bin
.coverage.*
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "mistralai", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
Expand All @@ -45,5 +45,6 @@ zhipuai = ["zhipuai"]
ollama = ["ollama", "httpx"]
qdrant = ["qdrant-client", "fastembed"]
vllm = ["vllm"]
pinecone = ["pinecone-client", "fastembed"]
opensearch = ["opensearch-py", "opensearch-dsl"]
hf = ["transformers"]
3 changes: 3 additions & 0 deletions src/vanna/mock/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .embedding import MockEmbedding
from .llm import MockLLM
from .vectordb import MockVectorDB
11 changes: 11 additions & 0 deletions src/vanna/mock/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from typing import List

from ..base import VannaBase


class MockEmbedding(VannaBase):
    """Stub embedding backend for tests.

    Ignores its input entirely and always produces the same
    5-dimensional vector, so tests never need a real embedding model.
    """

    def __init__(self, config=None):
        # The mock needs no configuration; accept and ignore it.
        pass

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the constant vector [1.0, 2.0, 3.0, 4.0, 5.0]."""
        return [float(i) for i in range(1, 6)]
19 changes: 19 additions & 0 deletions src/vanna/mock/llm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@

from ..base import VannaBase


class MockLLM(VannaBase):
    """Stub LLM backend for tests.

    Builds chat messages in the conventional ``{"role": ..., "content": ...}``
    shape and returns a fixed string for every prompt, so tests never hit a
    real model.
    """

    def __init__(self, config=None):
        # The mock needs no configuration; accept and ignore it.
        pass

    # NOTE: the original annotations used the builtin function ``any`` as a
    # return type, which is meaningless as an annotation; these helpers
    # demonstrably return plain dicts.
    def system_message(self, message: str) -> dict:
        """Wrap *message* as a system chat message."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap *message* as a user chat message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap *message* as an assistant chat message."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Return a canned response regardless of *prompt*."""
        return "Mock LLM response"
55 changes: 55 additions & 0 deletions src/vanna/mock/vectordb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import pandas as pd

from ..base import VannaBase


class MockVectorDB(VannaBase):
    """Stub vector store for tests.

    ``add_*`` methods return a deterministic-per-process ID without storing
    anything, the ``get_*`` lookups return empty lists, and
    ``get_training_data`` returns a small hard-coded sample DataFrame.
    """

    def __init__(self, config=None):
        # The mock needs no configuration; accept and ignore it.
        pass

    def _get_id(self, value: str, **kwargs) -> str:
        # Hash the value and return the ID.
        # NOTE(review): ``hash()`` on str is salted per interpreter run
        # (PYTHONHASHSEED), so IDs are only stable within one process —
        # acceptable for a mock, but do not persist these IDs.
        return str(hash(value))

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Pretend to store a DDL statement; return its mock ID."""
        return self._get_id(ddl)

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Pretend to store documentation; return its mock ID."""
        return self._get_id(doc)

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Pretend to store a question/SQL pair; return the question's mock ID."""
        return self._get_id(question)

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Always return no related DDL."""
        return []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Always return no related documentation."""
        return []

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Always return no similar question/SQL pairs."""
        return []

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return a fixed sample of training data (Chinook-style schema)."""
        return pd.DataFrame({'id': {0: '19546-ddl',
  1: '91597-sql',
  2: '133976-sql',
  3: '59851-doc',
  4: '73046-sql'},
 'training_data_type': {0: 'ddl',
  1: 'sql',
  2: 'sql',
  3: 'documentation',
  4: 'sql'},
 'question': {0: None,
  1: 'What are the top selling genres?',
  2: 'What are the low 7 artists by sales?',
  3: None,
  4: 'What is the total sales for each customer?'},
 'content': {0: 'CREATE TABLE [Invoice]\n(\n    [InvoiceId] INTEGER  NOT NULL,\n    [CustomerId] INTEGER  NOT NULL,\n    [InvoiceDate] DATETIME  NOT NULL,\n    [BillingAddress] NVARCHAR(70),\n    [BillingCity] NVARCHAR(40),\n    [BillingState] NVARCHAR(40),\n    [BillingCountry] NVARCHAR(40),\n    [BillingPostalCode] NVARCHAR(10),\n    [Total] NUMERIC(10,2)  NOT NULL,\n    CONSTRAINT [PK_Invoice] PRIMARY KEY  ([InvoiceId]),\n    FOREIGN KEY ([CustomerId]) REFERENCES [Customer] ([CustomerId]) \n\t\tON DELETE NO ACTION ON UPDATE NO ACTION\n)',
  1: 'SELECT g.Name AS Genre, SUM(il.Quantity) AS TotalSales\nFROM Genre g\nJOIN Track t ON g.GenreId = t.GenreId\nJOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY g.GenreId, g.Name\nORDER BY TotalSales DESC;',
  2: 'SELECT a.ArtistId, a.Name, SUM(il.Quantity) AS TotalSales\nFROM Artist a\nINNER JOIN Album al ON a.ArtistId = al.ArtistId\nINNER JOIN Track t ON al.AlbumId = t.AlbumId\nINNER JOIN InvoiceLine il ON t.TrackId = il.TrackId\nGROUP BY a.ArtistId, a.Name\nORDER BY TotalSales ASC\nLIMIT 7;',
  3: 'This is a SQLite database. For dates rememeber to use SQLite syntax.',
  4: 'SELECT c.CustomerId, c.FirstName, c.LastName, SUM(i.Total) AS TotalSales\nFROM Customer c\nJOIN Invoice i ON c.CustomerId = i.CustomerId\nGROUP BY c.CustomerId, c.FirstName, c.LastName;'}})

    # BUG FIX: the original signature was ``remove_training_data(id, **kwargs)``
    # with no ``self`` — as an instance method, ``self`` would have been bound
    # to ``id`` and the real ``id`` argument rejected.
    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Pretend to remove training data; always report success."""
        return True
3 changes: 3 additions & 0 deletions src/vanna/pinecone/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .pinecone_vector import PineconeDB_VectorStore

__all__ = ["PineconeDB_VectorStore"]
275 changes: 275 additions & 0 deletions src/vanna/pinecone/pinecone_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,275 @@
import json
from typing import List

from pinecone import Pinecone, PodSpec, ServerlessSpec
import pandas as pd
from ..base import VannaBase
from ..utils import deterministic_uuid

from fastembed import TextEmbedding


class PineconeDB_VectorStore(VannaBase):
    """
    Vectorstore using PineconeDB

    Args:
        config (dict): Configuration dictionary. Defaults to {}. You must provide either a Pinecone Client or an API key in the config.
            - client (Pinecone, optional): Pinecone client. Defaults to None.
            - api_key (str, optional): Pinecone API key. Defaults to None.
            - n_results (int, optional): Number of results to return. Defaults to 10.
            - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384 which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
            - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
            - documentation_namespace (str, optional): Namespace for documentation. Defaults to "documentation".
            - distance_metric (str, optional): Distance metric to use. Defaults to "cosine".
            - ddl_namespace (str, optional): Namespace for DDL. Defaults to "ddl".
            - sql_namespace (str, optional): Namespace for SQL. Defaults to "sql".
            - index_name (str, optional): Name of the index. Defaults to "vanna-index".
            - metadata_config (dict, optional): Metadata configuration if using a pinecone pod. Defaults to {}.
            - server_type (str, optional): Type of Pinecone server to use. Defaults to "serverless". Options are "serverless" or "pod".
            - podspec (PodSpec, optional): PodSpec configuration if using a pinecone pod. Defaults to PodSpec(environment="us-west-2", pod_type="p1.x1", metadata_config=self.metadata_config).
            - serverless_spec (ServerlessSpec, optional): ServerlessSpec configuration if using a pinecone serverless index. Defaults to ServerlessSpec(cloud="aws", region="us-west-2").
    Raises:
        ValueError: If config is None, api_key is not provided OR client is not provided, client is not an instance of Pinecone, or server_type is not "serverless" or "pod".
    """

    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)
        if config is None:
            raise ValueError(
                "config is required, pass either a Pinecone client or an API key in the config."
            )
        client = config.get("client")
        api_key = config.get("api_key")
        if not api_key and not client:
            raise ValueError(
                "api_key is required in config or pass a configured client"
            )
        # Prefer a user-supplied client; otherwise build one from the API key.
        if not client and api_key:
            self._client = Pinecone(api_key=api_key)
        elif not isinstance(client, Pinecone):
            raise ValueError("client must be an instance of Pinecone")
        else:
            self._client = client

        self.n_results = config.get("n_results", 10)
        self.dimensions = config.get("dimensions", 384)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")
        self.documentation_namespace = config.get(
            "documentation_namespace", "documentation"
        )
        self.distance_metric = config.get("distance_metric", "cosine")
        self.ddl_namespace = config.get("ddl_namespace", "ddl")
        self.sql_namespace = config.get("sql_namespace", "sql")
        self.index_name = config.get("index_name", "vanna-index")
        self.metadata_config = config.get("metadata_config", {})
        self.server_type = config.get("server_type", "serverless")
        if self.server_type not in ["serverless", "pod"]:
            raise ValueError("server_type must be either 'serverless' or 'pod'")
        self.podspec = config.get(
            "podspec",
            PodSpec(
                environment="us-west-2",
                pod_type="p1.x1",
                metadata_config=self.metadata_config,
            ),
        )
        self.serverless_spec = config.get(
            "serverless_spec", ServerlessSpec(cloud="aws", region="us-west-2")
        )
        # Lazily-created fastembed model (loading it is expensive, so it is
        # built on first embedding request and reused — see generate_embedding).
        self._embedding_model = None
        self._setup_index()

    def _set_index_host(self, host: str) -> None:
        # Bind the data-plane Index handle to the resolved host.
        self.Index = self._client.Index(host=host)

    def _setup_index(self) -> None:
        """Create the configured index if it does not exist, then connect to it.

        The spec used for creation depends on ``server_type``:
        ``serverless_spec`` for "serverless", ``podspec`` for "pod".
        """
        if self.index_name not in self._get_indexes():
            spec = (
                self.serverless_spec
                if self.server_type == "serverless"
                else self.podspec
            )
            self._client.create_index(
                name=self.index_name,
                dimension=self.dimensions,
                metric=self.distance_metric,
                spec=spec,
            )
        pinecone_index_host = self._client.describe_index(self.index_name)["host"]
        self._set_index_host(pinecone_index_host)

    def _get_indexes(self) -> list:
        """Return the names of all indexes visible to this client."""
        return [index["name"] for index in self._client.list_indexes()]

    def _check_if_embedding_exists(self, id: str, namespace: str) -> bool:
        """Return True if a vector with *id* already exists in *namespace*."""
        fetch_response = self.Index.fetch(ids=[id], namespace=namespace)
        if fetch_response["vectors"] == {}:
            return False
        return True

    def add_ddl(self, ddl: str, **kwargs) -> str:
        """Embed and upsert a DDL statement; return its deterministic ID.

        Skips the upsert (idempotent) if the ID is already present.
        """
        id = deterministic_uuid(ddl) + "-ddl"
        if self._check_if_embedding_exists(id=id, namespace=self.ddl_namespace):
            print(f"DDL with id: {id} already exists in the index. Skipping...")
            return id
        self.Index.upsert(
            vectors=[(id, self.generate_embedding(ddl), {"ddl": ddl})],
            namespace=self.ddl_namespace,
        )
        return id

    def add_documentation(self, doc: str, **kwargs) -> str:
        """Embed and upsert a documentation string; return its deterministic ID.

        Skips the upsert (idempotent) if the ID is already present.
        """
        id = deterministic_uuid(doc) + "-doc"

        if self._check_if_embedding_exists(
            id=id, namespace=self.documentation_namespace
        ):
            print(
                f"Documentation with id: {id} already exists in the index. Skipping..."
            )
            return id
        self.Index.upsert(
            vectors=[(id, self.generate_embedding(doc), {"documentation": doc})],
            namespace=self.documentation_namespace,
        )
        return id

    def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
        """Embed and upsert a question/SQL pair (stored as JSON); return its ID.

        Skips the upsert (idempotent) if the ID is already present.
        """
        question_sql_json = json.dumps(
            {
                "question": question,
                "sql": sql,
            },
            ensure_ascii=False,
        )
        id = deterministic_uuid(question_sql_json) + "-sql"
        if self._check_if_embedding_exists(id=id, namespace=self.sql_namespace):
            print(
                f"Question-SQL with id: {id} already exists in the index. Skipping..."
            )
            return id
        self.Index.upsert(
            vectors=[
                (
                    id,
                    self.generate_embedding(question_sql_json),
                    {"sql": question_sql_json},
                )
            ],
            namespace=self.sql_namespace,
        )
        return id

    def get_related_ddl(self, question: str, **kwargs) -> list:
        """Return the DDL strings most similar to *question* (top n_results)."""
        res = self.Index.query(
            namespace=self.ddl_namespace,
            vector=self.generate_embedding(question),
            top_k=self.n_results,
            include_values=True,
            include_metadata=True,
        )
        return [match["metadata"]["ddl"] for match in res["matches"]] if res else []

    def get_related_documentation(self, question: str, **kwargs) -> list:
        """Return the documentation strings most similar to *question*."""
        res = self.Index.query(
            namespace=self.documentation_namespace,
            vector=self.generate_embedding(question),
            top_k=self.n_results,
            include_values=True,
            include_metadata=True,
        )
        return (
            [match["metadata"]["documentation"] for match in res["matches"]]
            if res
            else []
        )

    def get_similar_question_sql(self, question: str, **kwargs) -> list:
        """Return question/SQL dicts most similar to *question*."""
        res = self.Index.query(
            namespace=self.sql_namespace,
            vector=self.generate_embedding(question),
            top_k=self.n_results,
            include_values=True,
            include_metadata=True,
        )
        return (
            [json.loads(match["metadata"]["sql"]) for match in res["matches"]]
            if res
            else []
        )

    def get_training_data(self, **kwargs) -> pd.DataFrame:
        """Return all stored training data as a DataFrame.

        Pinecone does not support listing all vectors in a namespace, so each
        namespace is queried with a zero vector and the maximum allowed top_k.
        """
        df = pd.DataFrame()
        namespaces = {
            "sql": self.sql_namespace,
            "ddl": self.ddl_namespace,
            "documentation": self.documentation_namespace,
        }

        for data_type, namespace in namespaces.items():
            data = self.Index.query(
                top_k=10000,  # max results that pinecone allows
                namespace=namespace,
                include_values=True,
                include_metadata=True,
                vector=[0.0] * self.dimensions,
            )

            if data is not None:
                id_list = [match["id"] for match in data["matches"]]
                content_list = [
                    match["metadata"][data_type] for match in data["matches"]
                ]
                # Only SQL entries carry a question (inside their JSON payload).
                question_list = [
                    (
                        json.loads(match["metadata"][data_type])["question"]
                        if data_type == "sql"
                        else None
                    )
                    for match in data["matches"]
                ]

                df_data = pd.DataFrame(
                    {
                        "id": id_list,
                        "question": question_list,
                        "content": content_list,
                    }
                )
                df_data["training_data_type"] = data_type
                df = pd.concat([df, df_data])

        return df

    def remove_training_data(self, id: str, **kwargs) -> bool:
        """Delete the vector with *id*, routing by its suffix to the namespace.

        Returns False if the suffix is not one of -sql/-ddl/-doc.
        """
        if id.endswith("-sql"):
            self.Index.delete(ids=[id], namespace=self.sql_namespace)
            return True
        elif id.endswith("-ddl"):
            self.Index.delete(ids=[id], namespace=self.ddl_namespace)
            return True
        elif id.endswith("-doc"):
            self.Index.delete(ids=[id], namespace=self.documentation_namespace)
            return True
        else:
            return False

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Embed *data* with the configured fastembed model.

        PERF FIX: the original constructed a new TextEmbedding (model load)
        on every call; the model is now created once and cached.
        """
        if getattr(self, "_embedding_model", None) is None:
            self._embedding_model = TextEmbedding(model_name=self.fastembed_model)
        embedding = next(self._embedding_model.embed(data))
        return embedding.tolist()
Loading