Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

vllm support added #386

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ gemini = ["google-generativeai"]
marqo = ["marqo"]
zhipuai = ["zhipuai"]
qdrant = ["qdrant-client"]
vllm = ["vllm"]
1 change: 1 addition & 0 deletions src/vanna/vllm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .vllm import Vllm
76 changes: 76 additions & 0 deletions src/vanna/vllm/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import re

import requests

from ..base import VannaBase


class Vllm(VannaBase):
    """LLM backend that sends chat prompts to a vLLM server over HTTP.

    Prompts are POSTed to ``{host}/v1/chat/completions`` and the text of the
    first returned choice is used as the model's answer.

    Config keys:
        vllm_host: Base URL of the server. Defaults to
            ``http://localhost:8000`` when absent.
        model: Name of the model to request. Required.
    """

    # Starts at the first standalone "select" keyword (case-insensitive) and
    # lazily consumes text up to the first ';', the first '```', or the end of
    # the string. The word boundaries keep words such as "selected" or
    # "selection" in the model's prose from being mistaken for the start of
    # the statement.
    _SQL_PATTERN = re.compile(
        r"\bselect\b.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL
    )

    def __init__(self, config=None):
        """Read the server address and model name from ``config``.

        Args:
            config: Optional mapping with the keys described on the class.

        Raises:
            ValueError: If ``config`` is missing or has no ``"model"`` entry.
        """
        if config is None:
            config = {}

        self.host = config.get("vllm_host", "http://localhost:8000")

        if "model" not in config:
            raise ValueError(
                "check the config for vllm: a 'model' entry is required"
            )
        self.model = config["model"]

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with the ``system`` role."""
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with the ``user`` role."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as a chat message with the ``assistant`` role."""
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """Pull the first ``SELECT`` statement out of a model response.

        The statement starts at the first standalone ``select`` keyword
        (case-insensitive) and runs to the first semicolon, the first triple
        backtick, or the end of the string, whichever comes first. A
        terminating semicolon is kept; triple backticks are removed.

        Args:
            text (str): The string to search for an SQL statement.

        Returns:
            str: The extracted statement, or ``text`` unchanged when it
            contains no ``select`` keyword.
        """
        match = self._SQL_PATTERN.search(text)
        if match is None:
            return text
        # The lazy match may end on the closing code fence; drop it.
        return match.group(0).replace("```", "")

    def generate_sql(self, question: str, **kwargs) -> str:
        """Generate SQL for ``question`` and clean up the raw model output.

        Args:
            question: The natural-language question to translate.
            **kwargs: Forwarded unchanged to the parent ``generate_sql``.

        Returns:
            The output of ``extract_sql_query`` applied to the parent's
            result after all backslashes have been removed.
        """
        sql = super().generate_sql(question, **kwargs)

        # Strip every backslash. This also turns escaped underscores ("\_")
        # back into plain "_", so no separate replacement is needed for them.
        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` to the server and return the reply text.

        Args:
            prompt: The value sent as the ``messages`` field of the request.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The ``content`` of the first choice's message in the response.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Tolerate a configured host that ends with a slash.
        url = f"{self.host.rstrip('/')}/v1/chat/completions"
        data = {
            "model": self.model,
            "stream": False,
            "messages": prompt,
        }

        # NOTE: no timeout is passed, so this call waits until the server
        # responds.
        response = requests.post(url, json=data)

        # Log the raw body before any parsing so that a failing response is
        # still visible in the log.
        self.log(response.text)

        # Surface HTTP failures directly instead of as a KeyError on
        # 'choices' or a JSON decoding error further down.
        response.raise_for_status()

        response_dict = response.json()

        return response_dict['choices'][0]['message']['content']
2 changes: 1 addition & 1 deletion tests/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def test_regular_imports():
from vanna.ZhipuAI.ZhipuAI_Chat import ZhipuAI_Chat
from vanna.ZhipuAI.ZhipuAI_embeddings import ZhipuAI_Embeddings


def test_shortcut_imports():
from vanna.anthropic import Anthropic_Chat
from vanna.base import VannaBase
Expand All @@ -25,4 +24,5 @@ def test_shortcut_imports():
from vanna.ollama import Ollama
from vanna.openai import OpenAI_Chat, OpenAI_Embeddings
from vanna.vannadb import VannaDB_VectorStore
from vanna.vllm import Vllm
from vanna.ZhipuAI import ZhipuAI_Chat, ZhipuAI_Embeddings