Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from langchain_community.tools.sql_coder.tool import (
QuerySparkSQLDataBaseTool,
SqlQueryCreatorTool,
RetrySqlQueryCreatorTool
)

class SQLCoderToolkit(BaseToolkit):
Expand Down Expand Up @@ -54,6 +55,7 @@ def get_tools(self) -> List[BaseTool]:
db=self.db, description=query_sql_database_tool_description
),
QuerySQLCheckerTool(db=self.db, llm=self.llm),
RetrySqlQueryCreatorTool(sqlcreatorllm=self.sqlcreatorllm),
SqlQueryCreatorTool(
sqlcreatorllm=self.sqlcreatorllm ,
db=self.db,
Expand Down
141 changes: 104 additions & 37 deletions libs/community/langchain_community/tools/sql_coder/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from langchain_core.tools import StateTool
import re

ERROR = ""

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Unused global variable

The ERROR variable is defined but never used in the code. Consider removing it if it's not needed.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

class BaseSQLDatabaseTool(BaseModel):
"""Base tool for interacting with a SQL database."""

Expand Down Expand Up @@ -43,7 +44,7 @@ class Config(StateTool.Config):
description: str = """
Input to this tool is a detailed and correct SQL query, output is a result from the database.
If the query is not correct, an error message will be returned.
If an error is returned, re-run the sql_db_query_creator tool to get the correct query.
If an error is returned, re-run the retry_sql_db_query_creator tool to get the correct query.
"""

def __init__(__pydantic_self__, **data: Any) -> None:
Expand All @@ -65,6 +66,7 @@ def _run(
)
executable_query = executable_query.strip('\"')
executable_query = re.sub('\\n```', '',executable_query)
self.db.run_no_throw(executable_query)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Duplicate database query execution

The self.db.run_no_throw(executable_query) is called twice consecutively. This seems redundant and could be removed.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

return self.db.run_no_throw(executable_query)

async def _arun(
Expand All @@ -75,14 +77,98 @@ async def _arun(
raise NotImplementedError("QuerySparkSQLDataBaseTool does not support async")

def _extract_sql_query(self):
for value in self.state:
for value in reversed(self.state):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question (performance): Reversed iteration over state

Reversing the state list might have performance implications if the list is large. Ensure this is necessary for the logic.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

for key, input_string in value.items():
if "sql_db_query_creator" in key:
if "tool='retry_sql_db_query_creator'" in key:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Hardcoded tool name

Consider defining the tool names as constants to avoid hardcoding strings multiple times.

Suggested change
if "tool='retry_sql_db_query_creator'" in key:
RETRY_SQL_DB_QUERY_CREATOR = "tool='retry_sql_db_query_creator'"
SQL_DB_QUERY_CREATOR = "tool='sql_db_query_creator'"
for value in reversed(self.state):
for key, input_string in value.items():
if RETRY_SQL_DB_QUERY_CREATOR in key:
return input_string
elif SQL_DB_QUERY_CREATOR in key:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

return input_string
elif "tool='sql_db_query_creator'" in key:
return input_string
return None



class RetrySqlQueryCreatorTool(StateTool):
"""Tool for re-creating SQL query.Use this to retry creation of sql query."""

name = "retry_sql_db_query_creator"
description = """
This is a tool used to re-create sql query for user input based on the incorrect query generated and error returned from sql_db_query tool.
Input to this tool is user prompt, incorrect sql query and error message
Output is a sql query
After running this tool, you can run sql_db_query tool to get the result
"""
sqlcreatorllm: BaseLanguageModel = Field(exclude=True)


class Config(StateTool.Config):
"""Configuration for this pydantic object."""

arbitrary_types_allowed = True
extra = Extra.allow

def __init__(__pydantic_self__, **data: Any) -> None:
"""Initialize the tool."""
super().__init__(**data)

def _run(
self,
user_input: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Get the SQL query for the incorrect query."""
return self._create_sql_query(user_input)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion: Missing type hint for method parameter

Consider adding type hints for the user_input parameter in the _run method for better code clarity.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?


async def _arun(
self,
table_name: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> str:
raise NotImplementedError("SqlQueryCreatorTool does not support async")

def _create_sql_query(self,user_input):

sql_query = self._extract_sql_query()
error_message = self._extract_error_message()
if sql_query is None:
return "This tool is not meant to be run directly. Start with a SQLQueryCreatorTool"

prompt_input = PromptTemplate(
input_variables=["user_input","sql_query", "error_message"],
template=SQL_QUERY_CREATOR_RETRY
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (bug_risk): Error handling for query creation

There is no error handling for the query_creator_chain.run method. Consider adding try-except blocks to handle potential exceptions.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

(
{
"sql_query": sql_query,
"error_message": error_message,
"user_input": user_input
}
)
)
sql_query = sql_query.replace("```","")
sql_query = sql_query.replace("sql","")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🚨 issue (security): Potentially unsafe string replacement

Replacing 'sql' in the query string might lead to unintended consequences if 'sql' appears in the actual query. Consider a more targeted approach.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test


return sql_query

def _extract_sql_query(self):
for value in reversed(self.state):
for key, input_string in value.items():
if "tool='retry_sql_db_query_creator'" in key:
return input_string
elif "tool='sql_db_query_creator'" in key:
return input_string
return None

def _extract_error_message(self):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue: Error message extraction logic

The method _extract_error_message assumes that the error message will always contain 'Error'. This might not be robust for all error messages.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

for value in reversed(self.state):
for key, input_string in value.items():
if "tool='sql_db_query'" in key:
if "Error" in input_string:
return input_string
return None

class SqlQueryCreatorTool(StateTool):
"""Tool for creating SQL query.Use this to create sql query."""

Expand Down Expand Up @@ -147,43 +233,24 @@ def _parse_data_model_context(self):
def _create_sql_query(self,user_input):

few_shot_examples = self._parse_few_shot_examples()
sql_query = self._extract_sql_query()
db_schema = self._parse_db_schema()
data_model_context = self._parse_data_model_context()
if sql_query is None:
prompt_input = PromptTemplate(
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"],
template=self.SQL_QUERY_CREATOR_TEMPLATE,
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(
(
{
"db_schema": db_schema,
"user_input": user_input,
"few_shot_examples": few_shot_examples,
"data_model_context": data_model_context
}
)
)
else:
prompt_input = PromptTemplate(
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"],
template=SQL_QUERY_CREATOR_RETRY
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(
(
{
"db_schema": db_schema,
"user_input": user_input,
"few_shot_examples": few_shot_examples,
"data_model_context": data_model_context
}
)
prompt_input = PromptTemplate(
input_variables=["db_schema", "user_input", "few_shot_examples","data_model_context"],
template=self.SQL_QUERY_CREATOR_TEMPLATE,
)
query_creator_chain = LLMChain(llm=self.sqlcreatorllm, prompt=prompt_input)

sql_query = query_creator_chain.run(
(
{
"db_schema": db_schema,
"user_input": user_input,
"few_shot_examples": few_shot_examples,
"data_model_context": data_model_context
}
)
)
sql_query = sql_query.replace("```","")
sql_query = sql_query.replace("sql","")

Expand Down
16 changes: 14 additions & 2 deletions libs/langchain/langchain/tools/sqlcoder/prompt.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@


SQL_QUERY_CREATOR_RETRY = """
You have failed in the first attempt to generate correct sql query. Please try again to rewrite correct sql query.
"""
Your task is convert an incorrect query resulting from user question to a correct query which is databricks sql compatible.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick (typo): Grammar issue in prompt

The sentence should be 'Your task is to convert an incorrect query resulting from a user question to a correct query which is Databricks SQL compatible.'

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this comment helpful?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment type correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the comment area correct?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What type of LLM test could this comment become?

  • 👍 - this comment is really good/important and we should always make it
  • 👎 - this comment is really bad and we should never make it
  • no reaction - don't turn this comment into an LLM test

Adhere to these rules:
- **Deliberately go through the question and database schema word by word** to appropriately answer the question
- **Use Table Aliases** to prevent ambiguity. For example, `SELECT table1.col1, table2.col1 FROM table1 JOIN table2 ON table1.id = table2.id`.
- When creating a ratio, always cast the numerator as float
### Task:
Generate a correct SQL query that answers the question [QUESTION]`{user_input}`[/QUESTION].
The query you will correct is: {sql_query}
The error message is: {error_message}
### Response:
Based on your instructions, here is the SQL query I have generated
[SQL]"""

SQL_QUERY_CREATOR_7b = """
### Instructions:
Expand Down