Skip to content

Commit

Permalink
Merge pull request vanna-ai#386 from Navanit-git/main
Browse files Browse the repository at this point in the history
vllm support added
  • Loading branch information
zainhoda authored and zyclove committed Apr 30, 2024
2 parents fcb69d6 + 4ad42ba commit 6555667
Show file tree
Hide file tree
Showing 6 changed files with 372 additions and 2 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi"

[project]
name = "vanna"
version = "0.4.3"
version = "0.4.3.1"
authors = [
{ name="Zain Hoda", email="zain@vanna.ai" },
]
Expand Down Expand Up @@ -42,3 +42,6 @@ gemini = ["google-generativeai"]
marqo = ["marqo"]
zhipuai = ["zhipuai"]
qdrant = ["qdrant-client"]
vllm = ["vllm"]
opensearch = ["opensearch-py", "opensearch-dsl"]

1 change: 1 addition & 0 deletions src/vanna/opensearch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .opensearch_vector import OpenSearch_VectorStore
289 changes: 289 additions & 0 deletions src/vanna/opensearch/opensearch_vector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
import base64
import uuid
from typing import List

import pandas as pd
from opensearchpy import OpenSearch

from ..base import VannaBase


class OpenSearch_VectorStore(VannaBase):
  """Vanna vector-store backend backed by an OpenSearch cluster.

  Training data is kept in three indices (documentation, DDL, and
  question/SQL pairs) and retrieved with plain full-text ``match``
  queries — no dense embeddings are computed (see generate_embedding).
  """

  def __init__(self, config=None):
    """Connect to OpenSearch.

    Recognized ``config`` keys (all optional):
      es_document_index / es_ddl_index / es_question_sql_index -- index names
      es_urls                     -- full cluster URL; takes precedence over host/port
      es_host / es_port           -- defaults "localhost" / 9200
      es_ssl / es_verify_certs    -- defaults False / False
      es_user / es_password       -- basic-auth credentials
      es_encoded_base64           -- if truthy, send credentials as a base64 header
      es_headers                  -- custom HTTP headers (override the auth header)
      es_timeout / es_max_retries -- defaults 60 / 10
    """
    VannaBase.__init__(self, config=config)
    config = config or {}

    self.document_index = config.get("es_document_index", "vanna_document_index")
    self.ddl_index = config.get("es_ddl_index", "vanna_ddl_index")
    self.question_sql_index = config.get(
      "es_question_sql_index", "vanna_questions_sql_index")
    print("OpenSearch_VectorStore initialized with document_index: ",
          self.document_index, " ddl_index: ", self.ddl_index,
          " question_sql_index: ", self.question_sql_index)

    es_urls = config.get("es_urls")
    host = config.get("es_host", "localhost")
    port = config.get("es_port", 9200)
    ssl = config.get("es_ssl", False)
    verify_certs = config.get("es_verify_certs", False)

    # Basic-auth tuple; None means unauthenticated.
    auth = None
    if "es_user" in config:
      auth = (config["es_user"], config["es_password"])

    headers = None
    # Optionally send the credentials as a pre-encoded Basic auth header
    # instead of the http_auth tuple.
    if "es_user" in config and "es_password" in config and config.get("es_encoded_base64"):
      encoded_credentials = base64.b64encode(
        (config["es_user"] + ":" + config["es_password"]).encode("utf-8")
      ).decode("utf-8")
      headers = {
        'Authorization': 'Basic ' + encoded_credentials
      }
      # Credentials travel in the header, so drop the auth tuple.
      auth = None

    # Custom headers take precedence over the generated auth header.
    if "es_headers" in config:
      headers = config["es_headers"]

    timeout = config.get("es_timeout", 60)
    max_retries = config.get("es_max_retries", 10)

    # Options shared by both connection styles.
    client_kwargs = dict(
      http_compress=True,
      use_ssl=ssl,
      verify_certs=verify_certs,
      timeout=timeout,
      max_retries=max_retries,
      retry_on_timeout=True,
      http_auth=auth,
      headers=headers,
    )
    if es_urls is not None:
      # Initialize the OpenSearch client by passing a list of URLs
      self.client = OpenSearch(hosts=[es_urls], **client_kwargs)
    else:
      # Initialize the OpenSearch client by passing a host and port
      self.client = OpenSearch(hosts=[{'host': host, 'port': port}],
                               **client_kwargs)

    # Smoke-test the connection; failures are reported but not fatal.
    try:
      info = self.client.info()
      print('OpenSearch cluster info:', info)
    except Exception as e:
      print('Error connecting to OpenSearch cluster:', e)

    # Create the indices if they don't exist
    # self.create_index()

  def create_index(self):
    """Create the three backing indices, tolerating ones that already exist."""
    for index in [self.document_index, self.ddl_index, self.question_sql_index]:
      try:
        self.client.indices.create(index)
      except Exception as e:
        # Most commonly resource_already_exists_exception; any other
        # creation error is surfaced in the same log line.
        print("Error creating index: ", e)
        print(f"opensearch index {index} already exists")

  def add_ddl(self, ddl: str, **kwargs) -> str:
    """Store a DDL statement; returns the generated document id ("<uuid>-ddl")."""
    doc_id = str(uuid.uuid4()) + "-ddl"
    response = self.client.index(index=self.ddl_index, body={"ddl": ddl},
                                 id=doc_id, **kwargs)
    return response['_id']

  def add_documentation(self, doc: str, **kwargs) -> str:
    """Store free-form documentation; returns the generated id ("<uuid>-doc")."""
    doc_id = str(uuid.uuid4()) + "-doc"
    response = self.client.index(index=self.document_index, id=doc_id,
                                 body={"doc": doc}, **kwargs)
    return response['_id']

  def add_question_sql(self, question: str, sql: str, **kwargs) -> str:
    """Store a question/SQL pair; returns the generated id ("<uuid>-sql")."""
    doc_id = str(uuid.uuid4()) + "-sql"
    response = self.client.index(index=self.question_sql_index,
                                 body={"question": question, "sql": sql},
                                 id=doc_id, **kwargs)
    return response['_id']

  def get_related_ddl(self, question: str, **kwargs) -> List[str]:
    """Full-text match of *question* against stored DDL statements."""
    query = {
      "query": {
        "match": {
          "ddl": question
        }
      }
    }
    response = self.client.search(index=self.ddl_index, body=query, **kwargs)
    return [hit['_source']['ddl'] for hit in response['hits']['hits']]

  def get_related_documentation(self, question: str, **kwargs) -> List[str]:
    """Full-text match of *question* against stored documentation."""
    query = {
      "query": {
        "match": {
          "doc": question
        }
      }
    }
    response = self.client.search(index=self.document_index, body=query,
                                  **kwargs)
    return [hit['_source']['doc'] for hit in response['hits']['hits']]

  def get_similar_question_sql(self, question: str, **kwargs) -> List[tuple]:
    """Full-text match against stored questions.

    Returns (question, sql) tuples; the original annotation said
    ``List[str]`` but the implementation has always returned tuples.
    """
    query = {
      "query": {
        "match": {
          "question": question
        }
      }
    }
    response = self.client.search(index=self.question_sql_index, body=query,
                                  **kwargs)
    return [(hit['_source']['question'], hit['_source']['sql'])
            for hit in response['hits']['hits']]

  def get_training_data(self, **kwargs) -> pd.DataFrame:
    """Dump up to 1000 records from each index into one DataFrame.

    Columns: id, training_data_type ("documentation" | "sql" | "ddl"),
    question, content.
    WARNING: Do not use this approach in production for large indices!
    """
    match_all = {"query": {"match_all": {}}}
    data = []

    response = self.client.search(index=self.document_index, body=match_all,
                                  size=1000)
    for hit in response['hits']['hits']:
      data.append(
        {
          "id": hit["_id"],
          "training_data_type": "documentation",
          "question": "",
          "content": hit["_source"]['doc'],
        }
      )

    response = self.client.search(index=self.question_sql_index,
                                  body=match_all, size=1000)
    for hit in response['hits']['hits']:
      data.append(
        {
          "id": hit["_id"],
          "training_data_type": "sql",
          "question": hit.get("_source", {}).get("question", ""),
          "content": hit.get("_source", {}).get("sql", ""),
        }
      )

    response = self.client.search(index=self.ddl_index, body=match_all,
                                  size=1000)
    for hit in response['hits']['hits']:
      data.append(
        {
          "id": hit["_id"],
          "training_data_type": "ddl",
          "question": "",
          "content": hit["_source"]['ddl'],
        }
      )

    return pd.DataFrame(data)

  def remove_training_data(self, id: str, **kwargs) -> bool:
    """Delete one training record; the id suffix selects the index.

    Returns True on success, False for an unknown suffix or a delete error.
    """
    try:
      if id.endswith("-sql"):
        # Fixed: kwargs were dropped for this branch in the original.
        self.client.delete(index=self.question_sql_index, id=id, **kwargs)
        return True
      elif id.endswith("-ddl"):
        self.client.delete(index=self.ddl_index, id=id, **kwargs)
        return True
      elif id.endswith("-doc"):
        self.client.delete(index=self.document_index, id=id, **kwargs)
        return True
      else:
        return False
    except Exception as e:
      # Fixed duplicated text in the original log message.
      print("Error deleting training data: ", e)
      return False

  def generate_embedding(self, data: str, **kwargs) -> List[float]:
    # Retrieval here is keyword-based (match queries), so no embedding is
    # computed; despite the annotation this intentionally returns None.
    pass


# Example: connect via URL with base64-encoded credentials:
# OpenSearch_VectorStore(config={'es_urls':
#     "https://opensearch-node.test.com:9200", 'es_encoded_base64': True,
#     'es_user': "admin", 'es_password': "admin", 'es_verify_certs': True})


# Example: connect via host/port with basic auth:
# OpenSearch_VectorStore(config={'es_host':
#     "https://opensearch-node.test.com", 'es_port': 9200, 'es_user': "admin",
#     'es_password': "admin", 'es_verify_certs': True})
1 change: 1 addition & 0 deletions src/vanna/vllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .vllm import Vllm
76 changes: 76 additions & 0 deletions src/vanna/vllm/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import re

import requests

from ..base import VannaBase


class Vllm(VannaBase):
  """LLM backend that talks to a vLLM server's OpenAI-compatible chat API."""

  def __init__(self, config=None):
    """Configure the vLLM endpoint.

    Config keys:
      vllm_host -- base URL of the server (default "http://localhost:8000")
      model     -- required name of the model served by vLLM

    Raises ValueError when ``model`` is missing.
    """
    if config is None or "vllm_host" not in config:
      self.host = "http://localhost:8000"
    else:
      self.host = config["vllm_host"]

    if config is None or "model" not in config:
      raise ValueError("check the config for vllm")
    else:
      self.model = config["model"]

  def system_message(self, message: str) -> dict:
    """Wrap *message* as an OpenAI-style system chat message."""
    return {"role": "system", "content": message}

  def user_message(self, message: str) -> dict:
    """Wrap *message* as an OpenAI-style user chat message."""
    return {"role": "user", "content": message}

  def assistant_message(self, message: str) -> dict:
    """Wrap *message* as an OpenAI-style assistant chat message."""
    return {"role": "assistant", "content": message}

  def extract_sql_query(self, text):
    """
    Extracts the first SQL statement after the word 'select', ignoring case,
    matches until the first semicolon, three backticks, or the end of the string,
    and removes three backticks if they exist in the extracted string.
    Args:
    - text (str): The string to search within for an SQL statement.
    Returns:
    - str: The first SQL statement found, with three backticks removed, or
      the original text unchanged if no match is found.
    """
    # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
    pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

    match = pattern.search(text)
    if match:
      # Remove three backticks from the matched string if they exist
      return match.group(0).replace("```", "")
    else:
      # No SELECT found: hand back the model output untouched so the
      # caller can still inspect it.
      return text

  def generate_sql(self, question: str, **kwargs) -> str:
    """Generate SQL via the base implementation, then clean up the output."""
    sql = super().generate_sql(question, **kwargs)

    # Models sometimes emit markdown-escaped identifiers; undo "\_" first,
    # then strip any remaining stray backslashes.
    sql = sql.replace("\\_", "_")
    sql = sql.replace("\\", "")

    return self.extract_sql_query(sql)

  def submit_prompt(self, prompt, **kwargs) -> str:
    """POST *prompt* (a list of chat messages) to the vLLM chat endpoint.

    Returns the assistant message content of the first choice.
    """
    url = f"{self.host}/v1/chat/completions"
    data = {
      "model": self.model,
      "stream": False,
      "messages": prompt,
    }

    # NOTE(review): no timeout is set, so a hung server blocks forever —
    # consider making it configurable.
    response = requests.post(url, json=data)

    self.log(response.text)
    # Fail with a clear HTTP error instead of a KeyError on the missing
    # 'choices' field when the server returns a non-2xx response.
    response.raise_for_status()

    response_dict = response.json()
    return response_dict['choices'][0]['message']['content']
Loading

0 comments on commit 6555667

Please sign in to comment.