Merge pull request #598 from lucaordronneau/feature/azuresearch-vector-support

Feature/azuresearch vector support
zainhoda authored Aug 21, 2024
2 parents c06d0ca + 926bc8f commit f854616
Showing 5 changed files with 267 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -33,7 +33,7 @@ bigquery = ["google-cloud-bigquery"]
snowflake = ["snowflake-connector-python"]
duckdb = ["duckdb"]
google = ["google-generativeai", "google-cloud-aiplatform"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client"]
all = ["psycopg2-binary", "db-dtypes", "PyMySQL", "google-cloud-bigquery", "snowflake-connector-python", "duckdb", "openai", "qianfan", "mistralai>=1.0.0", "chromadb", "anthropic", "zhipuai", "marqo", "google-generativeai", "google-cloud-aiplatform", "qdrant-client", "fastembed", "ollama", "httpx", "opensearch-py", "opensearch-dsl", "transformers", "pinecone-client", "pymilvus[model]","weaviate-client", "azure-search-documents", "azure-identity", "azure-common"]
test = ["tox"]
chromadb = ["chromadb"]
openai = ["openai"]
@@ -52,3 +52,4 @@ hf = ["transformers"]
milvus = ["pymilvus[model]"]
bedrock = ["boto3", "botocore"]
weaviate = ["weaviate-client"]
azuresearch = ["azure-search-documents", "azure-identity", "azure-common", "fastembed"]
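With this extra defined, the Azure Search dependencies (plus fastembed, which supplies the local embeddings used by the vector store) can be installed via:

pip install "vanna[azuresearch]"

They are also folded into the all extra above.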
1 change: 1 addition & 0 deletions src/vanna/azuresearch/__init__.py
@@ -0,0 +1 @@
from .azuresearch_vector import AzureAISearch_VectorStore
236 changes: 236 additions & 0 deletions src/vanna/azuresearch/azuresearch_vector.py
@@ -0,0 +1,236 @@
import ast
import json
from typing import List

import pandas as pd
from azure.core.credentials import AzureKeyCredential
from azure.search.documents import SearchClient
from azure.search.documents.indexes import SearchIndexClient
from azure.search.documents.indexes.models import (
    ExhaustiveKnnAlgorithmConfiguration,
    ExhaustiveKnnParameters,
    SearchableField,
    SearchField,
    SearchFieldDataType,
    SearchIndex,
    VectorSearch,
    VectorSearchAlgorithmKind,
    VectorSearchAlgorithmMetric,
    VectorSearchProfile,
)
from azure.search.documents.models import VectorFilterMode, VectorizedQuery
from fastembed import TextEmbedding

from ..base import VannaBase
from ..utils import deterministic_uuid


class AzureAISearch_VectorStore(VannaBase):
    """
    AzureAISearch_VectorStore is a class that provides a vector store for Azure AI Search.

    Args:
        config (dict): Configuration dictionary. Defaults to {}. You must provide an API key in the config.
            - azure_search_endpoint (str, optional): Azure Search endpoint. Defaults to "https://azcognetive.search.windows.net".
            - azure_search_api_key (str): Azure Search API key.
            - dimensions (int, optional): Dimensions of the embeddings. Defaults to 384, which corresponds to the dimensions of BAAI/bge-small-en-v1.5.
            - fastembed_model (str, optional): Fastembed model to use. Defaults to "BAAI/bge-small-en-v1.5".
            - index_name (str, optional): Name of the index. Defaults to "vanna-index".
            - n_results (int, optional): Number of results to return. Defaults to 10.
            - n_results_ddl (int, optional): Number of results to return for DDL queries. Defaults to the value of n_results.
            - n_results_sql (int, optional): Number of results to return for SQL queries. Defaults to the value of n_results.
            - n_results_documentation (int, optional): Number of results to return for documentation queries. Defaults to the value of n_results.

    Raises:
        ValueError: If config is None, or if 'azure_search_api_key' is not provided in the config.
    """
    def __init__(self, config=None):
        VannaBase.__init__(self, config=config)

        if config is None:
            raise ValueError(
                "config is required; pass an API key, 'azure_search_api_key', in the config."
            )

        self.config = config

        azure_search_endpoint = config.get("azure_search_endpoint", "https://azcognetive.search.windows.net")
        azure_search_api_key = config.get("azure_search_api_key")

        self.dimensions = config.get("dimensions", 384)
        self.fastembed_model = config.get("fastembed_model", "BAAI/bge-small-en-v1.5")

        self.index_name = config.get("index_name", "vanna-index")

        self.n_results_ddl = config.get("n_results_ddl", config.get("n_results", 10))
        self.n_results_sql = config.get("n_results_sql", config.get("n_results", 10))
        self.n_results_documentation = config.get("n_results_documentation", config.get("n_results", 10))

        if not azure_search_api_key:
            raise ValueError(
                "'azure_search_api_key' is required in config to use AzureAISearch_VectorStore"
            )

        self.index_client = SearchIndexClient(
            endpoint=azure_search_endpoint,
            credential=AzureKeyCredential(azure_search_api_key)
        )

        self.search_client = SearchClient(
            endpoint=azure_search_endpoint,
            index_name=self.index_name,
            credential=AzureKeyCredential(azure_search_api_key)
        )

        if self.index_name not in self._get_indexes():
            self._create_index()

    def _create_index(self) -> None:
        # Index schema: a string key, the raw document text, a type tag
        # (ddl / doc / sql), and the embedding vector used for kNN search.
        fields = [
            SearchableField(name="id", type=SearchFieldDataType.String, key=True, filterable=True),
            SearchableField(name="document", type=SearchFieldDataType.String, searchable=True, filterable=True),
            SearchField(name="type", type=SearchFieldDataType.String, filterable=True, searchable=True),
            SearchField(name="document_vector", type=SearchFieldDataType.Collection(SearchFieldDataType.Single), searchable=True, vector_search_dimensions=self.dimensions, vector_search_profile_name="ExhaustiveKnnProfile"),
        ]

        vector_search = VectorSearch(
            algorithms=[
                ExhaustiveKnnAlgorithmConfiguration(
                    name="ExhaustiveKnn",
                    kind=VectorSearchAlgorithmKind.EXHAUSTIVE_KNN,
                    parameters=ExhaustiveKnnParameters(
                        metric=VectorSearchAlgorithmMetric.COSINE
                    )
                )
            ],
            profiles=[
                VectorSearchProfile(
                    name="ExhaustiveKnnProfile",
                    algorithm_configuration_name="ExhaustiveKnn",
                )
            ]
        )

        index = SearchIndex(name=self.index_name, fields=fields, vector_search=vector_search)
        result = self.index_client.create_or_update_index(index)
        print(f'{result.name} created')

    def _get_indexes(self) -> list:
        return list(self.index_client.list_index_names())

    def add_ddl(self, ddl: str) -> str:
        id = deterministic_uuid(ddl) + "-ddl"
        document = {
            "id": id,
            "document": ddl,
            "type": "ddl",
            "document_vector": self.generate_embedding(ddl)
        }
        self.search_client.upload_documents(documents=[document])
        return id

    def add_documentation(self, doc: str) -> str:
        id = deterministic_uuid(doc) + "-doc"
        document = {
            "id": id,
            "document": doc,
            "type": "doc",
            "document_vector": self.generate_embedding(doc)
        }
        self.search_client.upload_documents(documents=[document])
        return id

    def add_question_sql(self, question: str, sql: str) -> str:
        question_sql_json = json.dumps({"question": question, "sql": sql}, ensure_ascii=False)
        id = deterministic_uuid(question_sql_json) + "-sql"
        document = {
            "id": id,
            "document": question_sql_json,
            "type": "sql",
            "document_vector": self.generate_embedding(question_sql_json)
        }
        self.search_client.upload_documents(documents=[document])
        return id

    def get_related_ddl(self, text: str) -> List[str]:
        result = []
        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
        df = pd.DataFrame(
            self.search_client.search(
                top=self.n_results_ddl,
                vector_queries=[vector_query],
                select=["id", "document", "type"],
                filter="type eq 'ddl'"
            )
        )

        if len(df):
            result = df["document"].tolist()
        return result

    def get_related_documentation(self, text: str) -> List[str]:
        result = []
        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")

        df = pd.DataFrame(
            self.search_client.search(
                top=self.n_results_documentation,
                vector_queries=[vector_query],
                select=["id", "document", "type"],
                filter="type eq 'doc'",
                vector_filter_mode=VectorFilterMode.PRE_FILTER
            )
        )

        if len(df):
            result = df["document"].tolist()
        return result

    def get_similar_question_sql(self, text: str) -> List[str]:
        result = []
        # Vectorize the text
        vector_query = VectorizedQuery(vector=self.generate_embedding(text), fields="document_vector")
        df = pd.DataFrame(
            self.search_client.search(
                top=self.n_results_sql,
                vector_queries=[vector_query],
                select=["id", "document", "type"],
                filter="type eq 'sql'"
            )
        )

        # If similar pairs were found, parse each stored JSON string back into a dict
        if len(df):
            result = [ast.literal_eval(element) for element in df["document"].tolist()]

        return result

    def get_training_data(self) -> pd.DataFrame:
        search = self.search_client.search(
            search_text="*",
            select=['id', 'document', 'type'],
            filter="(type eq 'sql') or (type eq 'ddl') or (type eq 'doc')"
        ).by_page()

        df = pd.DataFrame([item for page in search for item in page])

        if len(df):
            # Question/SQL pairs are stored as JSON; split them back into columns.
            df.loc[df["type"] == "sql", "question"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["question"])
            df.loc[df["type"] == "sql", "content"] = df.loc[df["type"] == "sql"]["document"].apply(lambda x: json.loads(x)["sql"])
            df.loc[df["type"] != "sql", "content"] = df.loc[df["type"] != "sql"]["document"]

            return df[["id", "question", "content", "type"]]

        return pd.DataFrame()

    def remove_training_data(self, id: str) -> bool:
        result = self.search_client.delete_documents(documents=[{'id': id}])
        return result[0].succeeded

    def remove_index(self):
        self.index_client.delete_index(self.index_name)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        # Note: the fastembed model is instantiated on every call.
        embedding_model = TextEmbedding(model_name=self.fastembed_model)
        embedding = next(embedding_model.embed(data))
        return embedding.tolist()
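For orientation, a minimal usage sketch (not part of this commit): the vector store is meant to be mixed in with a chat class, mirroring the commented-out test further down. The MyVanna class name, endpoint, keys, and model below are placeholders, not values from the source.

# Minimal sketch; endpoint, keys, and model are placeholders.
from vanna.azuresearch import AzureAISearch_VectorStore
from vanna.openai import OpenAI_Chat

class MyVanna(AzureAISearch_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
        AzureAISearch_VectorStore.__init__(self, config=config)
        OpenAI_Chat.__init__(self, config=config)

vn = MyVanna(config={
    "azure_search_endpoint": "https://my-service.search.windows.net",  # placeholder
    "azure_search_api_key": "<azure-search-api-key>",                  # placeholder
    "api_key": "<openai-api-key>",                                     # placeholder
    "model": "gpt-3.5-turbo",
})

# Training data becomes documents in the "vanna-index" index; ids are
# deterministic, so re-adding identical content overwrites in place.
vn.add_ddl("CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT)")
vn.add_question_sql("How many customers are there?", "SELECT COUNT(*) FROM customers")

# Retrieval runs an exhaustive-kNN cosine search filtered by document type.
print(vn.get_similar_question_sql("customer count"))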
2 changes: 2 additions & 0 deletions tests/test_imports.py
@@ -19,6 +19,7 @@ def test_regular_imports():
    from vanna.weaviate.weaviate_vector import WeaviateDatabase
    from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
    from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings
    from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore

def test_shortcut_imports():
    from vanna.anthropic import Anthropic_Chat
@@ -36,3 +36,4 @@ def test_shortcut_imports():
    from vanna.vllm import Vllm
    from vanna.weaviate import WeaviateDatabase
    from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings
    from vanna.azuresearch import AzureAISearch_VectorStore
26 changes: 26 additions & 0 deletions tests/test_vanna.py
@@ -24,6 +24,7 @@
SNOWFLAKE_ACCOUNT = os.environ['SNOWFLAKE_ACCOUNT']
SNOWFLAKE_USERNAME = os.environ['SNOWFLAKE_USERNAME']
SNOWFLAKE_PASSWORD = os.environ['SNOWFLAKE_PASSWORD']
# AZURE_SEARCH_API_KEY = os.environ['AZURE_SEARCH_API_KEY']

class VannaOpenAI(VannaDB_VectorStore, OpenAI_Chat):
    def __init__(self, config=None):
@@ -111,6 +112,31 @@ def test_vn_chroma():
    df = vn_chroma.run_sql(sql)
    assert len(df) == 7

# from vanna.azuresearch.azuresearch_vector import AzureAISearch_VectorStore


# class VannaAzureSearch(AzureAISearch_VectorStore, OpenAI_Chat):
#     def __init__(self, config=None):
#         AzureAISearch_VectorStore.__init__(self, config=config)
#         OpenAI_Chat.__init__(self, config=config)

# vn_azure_search = VannaAzureSearch(config={'azure_search_api_key': AZURE_SEARCH_API_KEY, 'api_key': OPENAI_API_KEY, 'model': 'gpt-3.5-turbo'})
# vn_azure_search.connect_to_sqlite('https://vanna.ai/Chinook.sqlite')

# def test_vn_azure_search():
#     existing_training_data = vn_azure_search.get_training_data()
#     print(existing_training_data)
#     if len(existing_training_data) > 0:
#         for _, training_data in existing_training_data.iterrows():
#             vn_azure_search.remove_training_data(training_data['id'])

#     df_ddl = vn_azure_search.run_sql("SELECT type, sql FROM sqlite_master WHERE sql is not null")
#     for ddl in df_ddl['sql'].to_list():
#         vn_azure_search.train(ddl=ddl)

#     sql = vn_azure_search.generate_sql("What are the top 7 customers by sales?")
#     df = vn_azure_search.run_sql(sql)
#     assert len(df) == 7

from vanna.milvus import Milvus_VectorStore

