Commit
feat: upgrade langchain version
jczhong84 committed Feb 3, 2024
1 parent 8531df1 commit 4db3423
Showing 23 changed files with 223 additions and 489 deletions.
7 changes: 7 additions & 0 deletions docs_website/docs/changelog/breaking_change.mdx
@@ -7,6 +7,13 @@ slug: /changelog

Here are the list of breaking changes that you should be aware of when updating Querybook:

+## v3.31.0
+
+Upgraded langchain to [0.1.6](https://blog.langchain.dev/langchain-v0-1-0/).
+
+- Some langchain packages are imported from different paths, e.g. `PromptTemplate` is now from `langchain.prompts`
+- Removed `StreamingWebsocketCallbackHandler` to adopt the new streaming approach.
+
## v3.29.0

Made below changes for `S3BaseExporter` (csv table uploader feature):
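For reference, a minimal sketch of the import migration the v3.31.0 entry describes, using the paths this commit adopts (the langchain 0.1.x split moves provider bindings such as `ChatOpenAI` into separate packages):

```python
# Old paths (langchain < 0.1), as removed throughout this commit:
#   from langchain import PromptTemplate
#   from langchain.chat_models import ChatOpenAI

# New paths (langchain 0.1.6):
from langchain.prompts import PromptTemplate  # prompt classes
from langchain_openai import ChatOpenAI  # chat models live in langchain-openai
```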
2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
    "name": "querybook",
-    "version": "3.30.0",
+    "version": "3.31.0",
    "description": "A Big Data Webapp",
    "private": true,
    "scripts": {
1 change: 1 addition & 0 deletions querybook/config/querybook_default_config.yaml
@@ -92,6 +92,7 @@ AI_ASSISTANT_CONFIG:
    model_args:
        model_name: ~
        temperature: ~
+        streaming: ~
    reserved_tokens: ~

EMBEDDINGS_PROVIDER: ~
6 changes: 0 additions & 6 deletions querybook/server/lib/ai_assistant/ai_socket.py
@@ -26,12 +26,6 @@ def _send(self, event_type, payload: dict = None):
    def send_data(self, data: dict):
        self._send("data", data)

-    def send_delta_data(self, data: str):
-        self._send("delta_data", data)
-
-    def send_delta_end(self):
-        self._send("delta_end")
-
    def send_tables_for_sql_gen(self, data: list[str]):
        self._send("tables", data)
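With `send_delta_data`/`send_delta_end` removed, a streaming producer pushes successive payloads through the remaining `send_data` event and signals completion by closing the socket. A minimal sketch of that calling pattern (the chunk source is illustrative; `send_data` and `close` are the methods this changeset actually uses):

```python
def stream_to_socket(socket, chunks):
    """Send each streamed chunk as a regular data event, then close.

    Replaces the old send_delta_data/send_delta_end protocol.
    """
    for payload in chunks:  # e.g. progressively parsed JSON dicts
        socket.send_data(payload)
    socket.close()  # closing the socket replaces the delta_end event
```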
16 changes: 4 additions & 12 deletions querybook/server/lib/ai_assistant/assistants/openai_assistant.py
@@ -1,7 +1,6 @@
import openai
import tiktoken
-from langchain.callbacks.manager import CallbackManager
-from langchain.chat_models import ChatOpenAI
+from langchain_openai import ChatOpenAI

from lib.ai_assistant.base_ai_assistant import BaseAIAssistant
from lib.logger import get_logger
@@ -46,19 +45,12 @@ def _get_token_count(self, ai_command: str, prompt: str) -> int:
        return len(encoding.encode(prompt))

    def _get_error_msg(self, error) -> str:
-        if isinstance(error, openai.error.AuthenticationError):
+        if isinstance(error, openai.AuthenticationError):
            return "Invalid OpenAI API key"

        return super()._get_error_msg(error)

-    def _get_llm(self, ai_command: str, prompt_length: int, callback_handler=None):
+    def _get_llm(self, ai_command: str, prompt_length: int):
        config = self._get_llm_config(ai_command)
-        if not callback_handler:
-            # non-streaming
-            return ChatOpenAI(**config)
-
-        return ChatOpenAI(
-            **config,
-            streaming=True,
-            callback_manager=CallbackManager([callback_handler])
-        )
+        return ChatOpenAI(**config)
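With the callback wiring gone, streaming is just another model argument carried in the LLM config (compare the new `streaming: ~` key in `querybook_default_config.yaml` above). A minimal sketch of what `_get_llm` now amounts to, with illustrative config values rather than Querybook's real resolved config:

```python
from langchain_openai import ChatOpenAI

# Illustrative model_args; in Querybook these come from AI_ASSISTANT_CONFIG.
config = {"model_name": "gpt-4", "temperature": 0, "streaming": True}

llm = ChatOpenAI(**config)  # no CallbackManager / callback_handler needed

# Token-level streaming is requested at call time instead of via callbacks:
for chunk in llm.stream("Write a one-line SQL comment."):
    print(chunk.content, end="", flush=True)
```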
90 changes: 60 additions & 30 deletions querybook/server/lib/ai_assistant/base_ai_assistant.py
@@ -1,7 +1,10 @@
import functools
import json
from abc import ABC, abstractmethod

+from langchain_core.language_models.base import BaseLanguageModel
+from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
+from pydantic import ValidationError
+
from app.db import with_session
from const.ai_assistant import (
    DEFAUTL_TABLE_SELECT_LIMIT,
@@ -17,16 +20,15 @@
from logic.metastore import get_table_by_name
from models.metastore import DataTableColumn
from models.query_execution import QueryExecution
-from pydantic.error_wrappers import ValidationError

-from .ai_socket import with_ai_socket
+from .ai_socket import AIWebSocket, with_ai_socket
+from .prompts.sql_edit_prompt import SQL_EDIT_PROMPT
from .prompts.sql_fix_prompt import SQL_FIX_PROMPT
from .prompts.sql_summary_prompt import SQL_SUMMARY_PROMPT
from .prompts.sql_title_prompt import SQL_TITLE_PROMPT
from .prompts.table_select_prompt import TABLE_SELECT_PROMPT
from .prompts.table_summary_prompt import TABLE_SUMMARY_PROMPT
from .prompts.text_to_sql_prompt import TEXT_TO_SQL_PROMPT
-from .streaming_web_socket_callback_handler import StreamingWebsocketCallbackHandler
from .tools.table_schema import (
    get_slimmed_table_schemas,
    get_table_schema_by_name,
@@ -57,11 +59,7 @@ def wrapper(self, *args, **kwargs):
        except Exception as e:
            LOG.error(e, exc_info=True)
            err_msg = self._get_error_msg(e)
-            callback_handler = kwargs.get("callback_handler")
-            if callback_handler:
-                callback_handler.stream.send_error(err_msg)
-            else:
-                raise Exception(err_msg) from e
+            raise Exception(err_msg) from e

    return wrapper

@@ -96,14 +94,12 @@ def _get_llm(
        self,
        ai_command: str,
        prompt_length: int,
-        callback_handler: StreamingWebsocketCallbackHandler = None,
-    ):
+    ) -> BaseLanguageModel:
        """return the large language model to use.

        Args:
            ai_command (str): AI command type
            prompt_length (str): The number of tokens in the prompt. Can be used to decide which model to use.
-            callback_handler (StreamingWebsocketCallbackHandler, optional): Callback handler to handle the straming result.
        """
        raise NotImplementedError()

@@ -112,7 +108,8 @@ def _get_sql_title_prompt(self, query):

    def _get_text_to_sql_prompt(self, dialect, question, table_schemas, original_query):
        context_limit = self._get_usable_token_count(AICommandType.TEXT_TO_SQL.value)
-        prompt = TEXT_TO_SQL_PROMPT.format(
+        prompt_template = SQL_EDIT_PROMPT if original_query else TEXT_TO_SQL_PROMPT
+        prompt = prompt_template.format(
            dialect=dialect,
            question=question,
            table_schemas=table_schemas,
@@ -122,7 +119,7 @@ def _get_text_to_sql_prompt(self, dialect, question, table_schemas, original_query):

        if token_count > context_limit:
            # if the prompt is too long, use slimmed table schemas
-            prompt = TEXT_TO_SQL_PROMPT.format(
+            prompt = prompt_template.format(
                dialect=dialect,
                question=question,
                table_schemas=get_slimmed_table_schemas(table_schemas),
@@ -184,6 +181,26 @@ def _get_query_execution_error(self, query_execution: QueryExecution) -> str:

        return error[:1000]

+    def _run_prompt_and_send(
+        self,
+        socket: AIWebSocket,
+        command: AICommandType,
+        llm: BaseLanguageModel,
+        prompt_text: str,
+    ):
+        """Run the prompt and send the response to the websocket. If the command is streaming, send the response in streaming mode."""
+
+        chain = llm | JsonOutputParser()
+
+        if self._get_llm_config(command.value).get("streaming", True):
+            for s in chain.stream(prompt_text):
+                socket.send_data(s)
+            socket.close()
+        else:
+            response = llm.invoke(prompt_text)
+            socket.send_data(response)
+            socket.close()
+
    @catch_error
    @with_session
    @with_ai_socket(command_type=AICommandType.TEXT_TO_SQL)
@@ -213,7 +230,9 @@ def generate_sql_query(
            # not finding any relevant tables
            # ask user to provide table names
            socket.send_data(
-                "Sorry, I can't find any relevant tables by the given context. Please provide table names above."
+                {
+                    "explanation": "Sorry, I can't find any relevant tables by the given context. Please provide table names above."
+                }
            )

            socket.close()
@@ -237,9 +256,14 @@ def generate_sql_query(
            prompt_length=self._get_token_count(
                AICommandType.TEXT_TO_SQL.value, prompt
            ),
-            callback_handler=StreamingWebsocketCallbackHandler(socket),
        )
-        return llm.predict(text=prompt)
+
+        self._run_prompt_and_send(
+            socket=socket,
+            command=AICommandType.TEXT_TO_SQL,
+            llm=llm,
+            prompt_text=prompt,
+        )

    @catch_error
    @with_ai_socket(command_type=AICommandType.SQL_TITLE)
@@ -248,16 +272,18 @@ def generate_title_from_query(self, query, socket=None):

        Args:
            query (str): SQL query
            stream (bool, optional): Whether to stream the result. Defaults to True.
-            callback_handler (CallbackHandler, optional): Callback handler to handle the straming result. Required if stream is True.
        """
        prompt = self._get_sql_title_prompt(query=query)
        llm = self._get_llm(
            ai_command=AICommandType.SQL_TITLE.value,
            prompt_length=self._get_token_count(AICommandType.SQL_TITLE.value, prompt),
-            callback_handler=StreamingWebsocketCallbackHandler(socket),
        )
-        return llm.predict(text=prompt)
+        self._run_prompt_and_send(
+            socket=socket,
+            command=AICommandType.SQL_TITLE,
+            llm=llm,
+            prompt_text=prompt,
+        )

    @catch_error
    @with_session
@@ -268,7 +294,7 @@ def query_auto_fix(
        socket=None,
        session=None,
    ):
-        """Generate title from SQL query.
+        """Fix a SQL query from the error message of a failed query execution.

        Args:
            query_execution_id (int): The failed query execution id
@@ -301,9 +327,13 @@ def query_auto_fix(
        llm = self._get_llm(
            ai_command=AICommandType.SQL_FIX.value,
            prompt_length=self._get_token_count(AICommandType.SQL_FIX.value, prompt),
-            callback_handler=StreamingWebsocketCallbackHandler(socket),
        )
-        return llm.predict(text=prompt)
+        self._run_prompt_and_send(
+            socket=socket,
+            command=AICommandType.SQL_FIX,
+            llm=llm,
+            prompt_text=prompt,
+        )

    @catch_error
    @with_session
@@ -337,9 +367,9 @@ def summarize_table(
            prompt_length=self._get_token_count(
                AICommandType.TABLE_SUMMARY.value, prompt
            ),
-            callback_handler=None,
        )
-        return llm.predict(text=prompt)
+        chain = llm | StrOutputParser()
+        return chain.invoke(prompt)

@catch_error
@with_session
@@ -365,9 +395,9 @@ def summarize_query(
            prompt_length=self._get_token_count(
                AICommandType.SQL_SUMMARY.value, prompt
            ),
-            callback_handler=None,
        )
-        return llm.predict(text=prompt)
+        chain = llm | StrOutputParser()
+        return chain.invoke(prompt)

@with_session
def find_tables(self, metastore_id, question, session=None):
@@ -422,9 +452,9 @@ def find_tables(self, metastore_id, question, session=None):
                prompt_length=self._get_token_count(
                    AICommandType.TABLE_SELECT.value, prompt
                ),
-                callback_handler=None,
            )
-            return json.loads(llm.predict(text=prompt))
+            chain = llm | JsonOutputParser()
+            return chain.invoke(prompt)
        except Exception as e:
            LOG.error(e, exc_info=True)
            return []
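The new `_run_prompt_and_send` leans on a LangChain Expression Language (LCEL) pipe: `JsonOutputParser` re-parses the accumulated tokens on every chunk, so each streamed value is a progressively more complete dict, which is why the socket can repeatedly `send_data` whole snapshots instead of the old `delta_data` fragments. A self-contained sketch of that behavior (model and prompt are illustrative):

```python
from langchain_core.output_parsers import JsonOutputParser
from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model_name="gpt-4", temperature=0, streaming=True)
chain = llm | JsonOutputParser()  # LCEL pipe: model output -> parser

prompt = 'Answer as a JSON object with keys "query" and "explanation": list all users'

# Each iteration yields a progressively completed dict, e.g.
#   {} -> {"query": "SELECT"} -> {"query": "SELECT * FROM users", ...}
for partial in chain.stream(prompt):
    print(partial)
```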
32 changes: 32 additions & 0 deletions querybook/server/lib/ai_assistant/prompts/sql_edit_prompt.py
@@ -0,0 +1,32 @@
+from langchain.prompts import PromptTemplate
+
+prompt_template = """
+You are a {dialect} expert.
+
+Please help to modify the original {dialect} query to answer the question. Your response should ONLY be based on the given context and follow the response guidelines and format instructions.
+
+===Tables
+{table_schemas}
+
+===Original Query
+{original_query}
+
+===Response Guidelines
+1. If the provided context is sufficient, please modify and generate a valid query without any explanations for the question. The query should start with a comment containing the question being asked.
+2. If the provided context is insufficient, please explain why it can't be generated.
+3. The original query may start with a comment containing a previously asked question. If you find such a comment, please use both the original question and the new question to modify the query, and update the comment accordingly.
+4. Please use the most relevant table(s).
+5. Please format the query before responding.
+6. Please always respond with a valid well-formed JSON object with the following format
+
+===Response Format
+{{
+    "query": "A generated SQL query when context is sufficient.",
+    "explanation": "An explanation of failing to generate the query."
+}}
+
+===Question
+{question}
+"""
+
+SQL_EDIT_PROMPT = PromptTemplate.from_template(prompt_template)
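A hedged usage sketch for the new template (argument values are illustrative): `PromptTemplate.from_template` infers `dialect`, `table_schemas`, `original_query`, and `question` as input variables, while the doubled `{{ }}` braces keep the literal JSON braces of the response format out of that variable set.

```python
prompt = SQL_EDIT_PROMPT.format(
    dialect="sqlite",
    table_schemas="users(id BIGINT, name TEXT, created_at TIMESTAMP)",
    original_query="-- who signed up today?\nSELECT * FROM users",
    question="only show their names",
)
print(prompt)  # rendered prompt text, ready to pass to an llm/chain
```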
57 changes: 26 additions & 31 deletions querybook/server/lib/ai_assistant/prompts/sql_fix_prompt.py
@@ -1,33 +1,28 @@
-from langchain import PromptTemplate
-
-
-prompt_template = (
-    "You are a SQL expert that can help fix SQL query errors.\n\n"
-    "Please help fix the query below based on the given error message and table schemas. \n\n"
-    "===SQL dialect\n"
-    "{dialect}\n\n"
-    "===Query\n"
-    "{query}\n\n"
-    "===Error\n"
-    "{error}\n\n"
-    "===Table Schemas\n"
-    "{table_schemas}\n\n"
-    "===Response Format\n"
-    "<@key-1@>\n"
-    "value-1\n\n"
-    "<@key-2@>\n"
-    "value-2\n\n"
-    "===Example response:\n"
-    "<@explanation@>\n"
-    "This is an explanation about the error\n\n"
-    "<@fix_suggestion@>\n"
-    "This is a recommended fix for the error\n\n"
-    "<@fixed_query@>\n"
-    "The fixed SQL query\n\n"
-    "===Response Guidelines\n"
-    "1. For the <@fixed_query@> section, it can only be a valid SQL query without any explanation.\n"
-    "2. If there is insufficient context to address the query error, you may leave the fixed_query section blank and provide a general suggestion instead.\n"
-    "3. Maintain the original query format and case in the fixed_query section, including comments, except when correcting the erroneous part.\n"
-)
+from langchain.prompts import PromptTemplate
+
+
+prompt_template = """You are a {dialect} expert that can help fix SQL query errors.
+
+Please help fix below {dialect} query based on the given error message and table schemas.
+
+===Query
+{query}
+
+===Error
+{error}
+
+===Table Schemas
+{table_schemas}
+
+===Response Guidelines
+1. If there is insufficient context to address the query error, please leave fixed_query blank and provide a general suggestion instead.
+2. Maintain the original query format and case for the fixed_query, including comments, except when correcting the erroneous part.
+
+===Response Format
+{{
+    "explanation": "An explanation about the error",
+    "fix_suggestion": "A recommended fix for the error",
+    "fixed_query": "A valid and well formatted fixed query"
+}}
+"""

SQL_FIX_PROMPT = PromptTemplate.from_template(prompt_template)
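Replacing the bespoke `<@key@>` sections with plain JSON lets the fix flow reuse the same `JsonOutputParser` that `_run_prompt_and_send` relies on, instead of a custom section parser. A small sketch with a hypothetical model response:

```python
from langchain_core.output_parsers import JsonOutputParser

parser = JsonOutputParser()
response = '{"explanation": "Column nam does not exist", "fix_suggestion": "Did you mean name?", "fixed_query": "SELECT name FROM users"}'
print(parser.parse(response)["fixed_query"])  # -> SELECT name FROM users
```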
@@ -1,4 +1,4 @@
-from langchain import PromptTemplate
+from langchain.prompts import PromptTemplate


prompt_template = """
