feat: Add RetrySqlQueryCreatorTool for handling failed SQL query generation #1
base: base-sha/44e9c005f114a3b74ad3ceb87698d4044c012875
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -14,6 +14,7 @@ | |||||||||||||||||||
| from langchain_core.tools import StateTool | ||||||||||||||||||||
| import re | ||||||||||||||||||||
|
|
||||||||||||||||||||
| ERROR = "" | ||||||||||||||||||||
| class BaseSQLDatabaseTool(BaseModel): | ||||||||||||||||||||
| """Base tool for interacting with a SQL database.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -43,7 +44,7 @@ class Config(StateTool.Config): | |||||||||||||||||||
| description: str = """ | ||||||||||||||||||||
| Input to this tool is a detailed and correct SQL query, output is a result from the database. | ||||||||||||||||||||
| If the query is not correct, an error message will be returned. | ||||||||||||||||||||
| If an error is returned, re-run the sql_db_query_creator tool to get the correct query. | ||||||||||||||||||||
| If an error is returned, re-run the retry_sql_db_query_creator tool to get the correct query. | ||||||||||||||||||||
| """ | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def __init__(__pydantic_self__, **data: Any) -> None: | ||||||||||||||||||||
|
|
@@ -65,6 +66,7 @@ def _run( | |||||||||||||||||||
| ) | ||||||||||||||||||||
| executable_query = executable_query.strip('\"') | ||||||||||||||||||||
| executable_query = re.sub('\\n```', '',executable_query) | ||||||||||||||||||||
| self.db.run_no_throw(executable_query) | ||||||||||||||||||||
|
issue: Duplicate database query execution
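A minimal sketch of the duplicate-execution concern, assuming a hypothetical `FakeDB` stand-in that only counts calls (it is not the real database wrapper). A bare `run_no_throw` call followed by a second call for the return value executes the statement twice, which matters for expensive or non-idempotent statements.

```python
class FakeDB:
    """Hypothetical stand-in for the database wrapper; it only counts calls."""

    def __init__(self) -> None:
        self.calls = 0

    def run_no_throw(self, query: str) -> str:
        self.calls += 1
        return f"result of: {query}"


def run_twice(db: FakeDB, query: str) -> str:
    # Mirrors the diff: one bare call, then a second call for the return value.
    db.run_no_throw(query)
    return db.run_no_throw(query)


def run_once(db: FakeDB, query: str) -> str:
    # Execute a single time and return that result.
    return db.run_no_throw(query)


db_a, db_b = FakeDB(), FakeDB()
run_twice(db_a, "SELECT 1")
run_once(db_b, "SELECT 1")
print(db_a.calls, db_b.calls)
```

Dropping the bare call and keeping only the `return` line would be one way to address this.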
|
||||||||||||||||||||
| return self.db.run_no_throw(executable_query) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| async def _arun( | ||||||||||||||||||||
|
|
@@ -75,14 +77,98 @@ async def _arun( | |||||||||||||||||||
| raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _extract_sql_query(self): | ||||||||||||||||||||
| for value in self.state: | ||||||||||||||||||||
| for value in reversed(self.state): | ||||||||||||||||||||
|
question (performance): Reversed iteration over state
Reversing the state list might have performance implications if the list is large. Ensure this is necessary for the logic.
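On the performance question: for a Python list, the built-in `reversed()` returns a lazy iterator and does not copy the list, so the extra cost is usually negligible. A small sketch, assuming `self.state` is a plain list of dicts (the sample `state` below is hypothetical):

```python
state = [{"step": i} for i in range(5)]  # hypothetical stand-in for self.state

rev = reversed(state)  # a reverse iterator over the same list; no copy is made
assert not isinstance(rev, list)

# Early exit: stop at the most recent entry that matches a condition.
latest = next(item for item in reversed(state) if item["step"] % 2 == 1)
print(latest)
```

Iterating newest-first also lets the lookup return as soon as the latest matching entry is found.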
|
||||||||||||||||||||
| for key, input_string in value.items(): | ||||||||||||||||||||
| if "sql_db_query_creator" in key: | ||||||||||||||||||||
| if "tool='retry_sql_db_query_creator'" in key: | ||||||||||||||||||||
|
suggestion: Hardcoded tool name
Consider defining the tool names as constants to avoid hardcoding strings multiple times.
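One way the constants suggestion could look, as a sketch only. The constant names, the `tool_marker` helper, and the sample state keys are all hypothetical; the key format is assumed from the `tool='...'` substring checks in the diff.

```python
# Hypothetical module-level constants; the names are illustrative only.
SQL_QUERY_CREATOR = "sql_db_query_creator"
RETRY_SQL_QUERY_CREATOR = "retry_sql_db_query_creator"


def tool_marker(tool_name: str) -> str:
    """Build the substring the diff searches for in state keys."""
    return f"tool='{tool_name}'"


def extract_sql_query(state):
    """Return the most recent query produced by either creator tool, else None."""
    markers = (tool_marker(RETRY_SQL_QUERY_CREATOR), tool_marker(SQL_QUERY_CREATOR))
    for entry in reversed(state):
        for key, value in entry.items():
            if any(marker in key for marker in markers):
                return value
    return None


# Hypothetical state entries for illustration.
state = [
    {"tool='sql_db_query_creator' tool_input='q'": "SELECT a FROM t"},
    {"tool='sql_db_query' tool_input='SELECT a FROM t'": "Error: no such column a"},
    {"tool='retry_sql_db_query_creator' tool_input='q'": "SELECT b FROM t"},
]
print(extract_sql_query(state))
```

Because both branches in the diff return the same value, a single membership test over the markers is equivalent and keeps the names in one place.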
|
||||||||||||||||||||
| return input_string | ||||||||||||||||||||
| elif "tool='sql_db_query_creator'" in key: | ||||||||||||||||||||
| return input_string | ||||||||||||||||||||
| return None | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class RetrySqlQueryCreatorTool(StateTool): | ||||||||||||||||||||
| """Tool for re-creating SQL query.Use this to retry creation of sql query.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| name = "retry_sql_db_query_creator" | ||||||||||||||||||||
| description = """ | ||||||||||||||||||||
| This is a tool used to re-create sql query for user input based on the incorrect query generated and error returned from sql_db_query tool. | ||||||||||||||||||||
| Input to this tool is user prompt, incorrect sql query and error message | ||||||||||||||||||||
| Output is a sql query | ||||||||||||||||||||
| After running this tool, you can run sql_db_query tool to get the result | ||||||||||||||||||||
| """ | ||||||||||||||||||||
| sqlcreatorllm: BaseLanguageModel = Field(exclude=True) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class Config(StateTool.Config): | ||||||||||||||||||||
| """Configuration for this pydantic object.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| arbitrary_types_allowed = True | ||||||||||||||||||||
| extra = Extra.allow | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def __init__(__pydantic_self__, **data: Any) -> None: | ||||||||||||||||||||
| """Initialize the tool.""" | ||||||||||||||||||||
| super().__init__(**data) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _run( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| user_input: str, | ||||||||||||||||||||
| run_manager: Optional[CallbackManagerForToolRun] = None, | ||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||
| """Get the SQL query for the incorrect query.""" | ||||||||||||||||||||
| return self._create_sql_query(user_input) | ||||||||||||||||||||
|
suggestion: Missing type hint for method parameter |
||||||||||||||||||||
|
|
||||||||||||||||||||
| async def _arun( | ||||||||||||||||||||
| self, | ||||||||||||||||||||
| table_name: str, | ||||||||||||||||||||
| run_manager: Optional[AsyncCallbackManagerForToolRun] = None, | ||||||||||||||||||||
| ) -> str: | ||||||||||||||||||||
| raise NotImplementedError("SqlQueryCreatorTool does not support async") | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _create_sql_query(self,user_input): | ||||||||||||||||||||
|
|
||||||||||||||||||||
| sql_query = self._extract_sql_query() | ||||||||||||||||||||
| error_message = self._extract_error_message() | ||||||||||||||||||||
| if sql_query is None: | ||||||||||||||||||||
| return "This tool is not meant to be run directly. Start with a SQLQueryCreatorTool" | ||||||||||||||||||||
|
|
||||||||||||||||||||
| prompt_input = PromptTemplate( | ||||||||||||||||||||
| input_variables=["user_input","sql_query", "error_message"], | ||||||||||||||||||||
| template=SQL_QUERY_CREATOR_RETRY | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| sql_query = query_creator_chain.run( | ||||||||||||||||||||
|
issue (bug_risk): Error handling for query creation
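A hedged sketch of what error handling around the chain call might look like. `create_query_safely` and `failing_chain` are hypothetical names; the real call in the diff is `query_creator_chain.run(...)`, and the concrete exception types it can raise are not known from this page, so the handler is deliberately broad.

```python
from typing import Callable, Mapping


def create_query_safely(
    run_chain: Callable[[Mapping[str, str]], str],
    inputs: Mapping[str, str],
) -> str:
    """Call the chain and turn any failure into an error string for the agent."""
    try:
        return run_chain(inputs)
    except Exception as exc:  # broad on purpose: concrete exception types are unknown here
        return f"Error: query creation failed: {exc}"


def failing_chain(inputs):
    # Hypothetical stand-in for query_creator_chain.run that always fails.
    raise RuntimeError("model endpoint unavailable")


result = create_query_safely(failing_chain, {"user_input": "total sales by region"})
print(result)
```

Returning an error string rather than raising matches how the query tool in this diff reports failures, but whether that is the desired behavior here is a design decision for the author.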
|
||||||||||||||||||||
| ( | ||||||||||||||||||||
| { | ||||||||||||||||||||
| "sql_query": sql_query, | ||||||||||||||||||||
| "error_message": error_message, | ||||||||||||||||||||
| "user_input": user_input | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| sql_query = sql_query.replace("```","") | ||||||||||||||||||||
| sql_query = sql_query.replace("sql","") | ||||||||||||||||||||
|
🚨 issue (security): Potentially unsafe string replacement
Replacing 'sql' in the query string might lead to unintended consequences if 'sql' appears in the actual query. Consider a more targeted approach.
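To illustrate the replacement concern, here is a sketch of a more targeted approach that strips only Markdown code fences (with an optional `sql` language tag) and leaves the query body alone. `strip_code_fence` and the sample query are hypothetical.

```python
import re

# Matches a code fence with an optional "sql" language tag and trailing newline.
_FENCE = re.compile(r"```(?:sql\b)?[ \t]*\n?", re.IGNORECASE)


def strip_code_fence(text: str) -> str:
    """Remove Markdown code fences but leave the query body untouched."""
    return _FENCE.sub("", text).strip()


raw = "```sql\nSELECT sql_mode, name FROM mysql_settings\n```"

# The approach in the diff: removes every occurrence of "sql", including identifiers.
naive = raw.replace("```", "").replace("sql", "")

print(repr(naive.strip()))
print(repr(strip_code_fence(raw)))
```

With the naive version, column and table names that contain `sql` are silently altered, so the repaired query can fail or target the wrong object.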
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| return sql_query | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _extract_sql_query(self): | ||||||||||||||||||||
| for value in reversed(self.state): | ||||||||||||||||||||
| for key, input_string in value.items(): | ||||||||||||||||||||
| if "tool='retry_sql_db_query_creator'" in key: | ||||||||||||||||||||
| return input_string | ||||||||||||||||||||
| elif "tool='sql_db_query_creator'" in key: | ||||||||||||||||||||
| return input_string | ||||||||||||||||||||
| return None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| def _extract_error_message(self): | ||||||||||||||||||||
|
issue: Error message extraction logic
|
||||||||||||||||||||
| for value in reversed(self.state): | ||||||||||||||||||||
| for key, input_string in value.items(): | ||||||||||||||||||||
| if "tool='sql_db_query'" in key: | ||||||||||||||||||||
| if "Error" in input_string: | ||||||||||||||||||||
| return input_string | ||||||||||||||||||||
| return None | ||||||||||||||||||||
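A small sketch of how this lookup behaves, using a hypothetical state list whose key format is assumed from the substring checks in the diff. It shows one property of the `"Error" in ...` test: a successful result whose data happens to contain the word "Error" is also treated as an error message.

```python
def extract_error_message(state):
    """Equivalent to the diff: newest sql_db_query output that contains 'Error'."""
    for entry in reversed(state):
        for key, output in entry.items():
            if "tool='sql_db_query'" in key and "Error" in output:
                return output
    return None


# Hypothetical state; the second entry is a successful result, not a failure.
state = [
    {"tool='sql_db_query' tool_input='SELECT x FROM t'": "Error: column x not found"},
    {"tool='sql_db_query' tool_input='SELECT msg FROM logs'": "[('Error while syncing',)]"},
]

print(extract_error_message(state))
```

Checking for a known error prefix instead of a bare substring would be one option, assuming the database wrapper formats its errors consistently.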
|
|
||||||||||||||||||||
| class SqlQueryCreatorTool(StateTool): | ||||||||||||||||||||
| """Tool for creating SQL query.Use this to create sql query.""" | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -147,43 +233,24 @@ def _parse_data_model_context(self): | |||||||||||||||||||
| def _create_sql_query(self,user_input): | ||||||||||||||||||||
|
|
||||||||||||||||||||
| few_shot_examples = self._parse_few_shot_examples() | ||||||||||||||||||||
| sql_query = self._extract_sql_query() | ||||||||||||||||||||
| db_schema = self._parse_db_schema() | ||||||||||||||||||||
| data_model_context = self._parse_data_model_context() | ||||||||||||||||||||
| if sql_query is None: | ||||||||||||||||||||
| prompt_input = PromptTemplate( | ||||||||||||||||||||
| input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], | ||||||||||||||||||||
| template=self.SQL_QUERY_CREATOR_TEMPLATE, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| sql_query = query_creator_chain.run( | ||||||||||||||||||||
| ( | ||||||||||||||||||||
| { | ||||||||||||||||||||
| "db_schema": db_schema, | ||||||||||||||||||||
| "user_input": user_input, | ||||||||||||||||||||
| "few_shot_examples": few_shot_examples, | ||||||||||||||||||||
| "data_model_context": data_model_context | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| else: | ||||||||||||||||||||
| prompt_input = PromptTemplate( | ||||||||||||||||||||
| input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], | ||||||||||||||||||||
| template=SQL_QUERY_CREATOR_RETRY | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| sql_query = query_creator_chain.run( | ||||||||||||||||||||
| ( | ||||||||||||||||||||
| { | ||||||||||||||||||||
| "db_schema": db_schema, | ||||||||||||||||||||
| "user_input": user_input, | ||||||||||||||||||||
| "few_shot_examples": few_shot_examples, | ||||||||||||||||||||
| "data_model_context": data_model_context | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| prompt_input = PromptTemplate( | ||||||||||||||||||||
| input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"], | ||||||||||||||||||||
| template=self.SQL_QUERY_CREATOR_TEMPLATE, | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| sql_query = query_creator_chain.run( | ||||||||||||||||||||
| ( | ||||||||||||||||||||
| { | ||||||||||||||||||||
| "db_schema": db_schema, | ||||||||||||||||||||
| "user_input": user_input, | ||||||||||||||||||||
| "few_shot_examples": few_shot_examples, | ||||||||||||||||||||
| "data_model_context": data_model_context | ||||||||||||||||||||
| } | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| ) | ||||||||||||||||||||
| sql_query = sql_query.replace("```","") | ||||||||||||||||||||
| sql_query = sql_query.replace("sql","") | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,20 @@ | ||
|
|
||
|
|
||
| SQL_QUERY_CREATOR_RETRY = """ | ||
| You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query. | ||
| """ | ||
| Your task is convert an incorrect query resulting from user question to a correct query which is databricks sql compatible. | ||
|
nitpick (typo): Grammar issue in prompt
The sentence should be 'Your task is to convert an incorrect query resulting from a user question to a correct query which is Databricks SQL compatible.'
|
||
| Adhere to these rules: | ||
| - **Deliberately go through the question and database schema word by word** to appropriately answer the question | ||
| - **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`. | ||
| - When creating a ratio, always cast the numerator as float | ||
| ### Task: | ||
| Generate a correct SQL query that answers the question [QUESTION]`{user_input}`[/QUESTION]. | ||
| The query you will correct is: {sql_query} | ||
| The error message is: {error_message} | ||
| ### Response: | ||
| Based on your instructions, here is the SQL query I have generated | ||
| [SQL]""" | ||
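The retry template expects exactly three placeholders, and they need to match the `input_variables` passed to the prompt. A quick check using only the standard library, assuming `str.format`-style placeholders; `TEMPLATE` is a shortened, hypothetical copy of the prompt and `placeholders` is an illustrative helper.

```python
from string import Formatter

# Shortened, hypothetical copy of the retry template; only the placeholders matter here.
TEMPLATE = (
    "Generate a correct SQL query that answers the question "
    "[QUESTION]`{user_input}`[/QUESTION].\n"
    "The query you will correct is: {sql_query}\n"
    "The error message is: {error_message}\n"
)


def placeholders(template: str) -> set:
    """Return the set of named fields in a str.format-style template."""
    return {name for _, name, _, _ in Formatter().parse(template) if name}


declared = {"user_input", "sql_query", "error_message"}
print(placeholders(TEMPLATE) == declared)
```

A check like this could serve as a unit test so that template edits and `input_variables` do not drift apart.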
|
|
||
| SQL_QUERY_CREATOR_7b = """ | ||
| ### Instructions: | ||
|
|
||
issue: Unused global variable
The `ERROR` variable is defined but never used in the code. Consider removing it if it's not needed.