-
Notifications
You must be signed in to change notification settings - Fork 4.4k
/
sql_wrapper.py
232 lines (202 loc) · 9.41 KB
/
sql_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
"""SQL wrapper around SQLDatabase in langchain."""
from typing import Any, Dict, Iterable, List, Optional, Tuple
from sqlalchemy import MetaData, create_engine, insert, inspect, text
from sqlalchemy.engine import Engine
from sqlalchemy.exc import OperationalError, ProgrammingError
class SQLDatabase:
    """SQL Database.

    This class provides a wrapper around the SQLAlchemy engine to interact with a SQL
    database.
    It provides methods to execute SQL commands, insert data into tables, and retrieve
    information about the database schema.
    It also supports optional features such as including or excluding specific tables,
    sampling rows for table info,
    including indexes in table info, and supporting views.

    Based on langchain SQLDatabase.
    https://github.com/langchain-ai/langchain/blob/e355606b1100097665207ca259de6dc548d44c78/libs/langchain/langchain/utilities/sql_database.py#L39

    Args:
        engine (Engine): The SQLAlchemy engine instance to use for database operations.
        schema (Optional[str]): The name of the schema to use, if any.
        metadata (Optional[MetaData]): The metadata instance to use, if any.
        ignore_tables (Optional[List[str]]): List of table names to ignore. If set,
            include_tables must be None.
        include_tables (Optional[List[str]]): List of table names to include. If set,
            ignore_tables must be None.
        sample_rows_in_table_info (int): The number of sample rows to include in table
            info.
        indexes_in_table_info (bool): Whether to include indexes in table info.
        custom_table_info (Optional[dict]): Custom table info to use.
        view_support (bool): Whether to support views.
        max_string_length (int): The maximum string length to use.

    """

    def __init__(
        self,
        engine: Engine,
        schema: Optional[str] = None,
        metadata: Optional[MetaData] = None,
        ignore_tables: Optional[List[str]] = None,
        include_tables: Optional[List[str]] = None,
        sample_rows_in_table_info: int = 3,
        indexes_in_table_info: bool = False,
        custom_table_info: Optional[dict] = None,
        view_support: bool = False,
        max_string_length: int = 300,
    ):
        """Wrap an existing engine and reflect the usable tables.

        Raises:
            ValueError: If both ``include_tables`` and ``ignore_tables`` are given,
                or if either names a table that the inspector does not report.
            TypeError: If ``sample_rows_in_table_info`` is not an int, or a truthy
                ``custom_table_info`` is not a dict.
        """
        self._engine = engine
        self._schema = schema
        if include_tables and ignore_tables:
            raise ValueError("Cannot specify both include_tables and ignore_tables")

        self._inspector = inspect(self._engine)

        # including view support by adding the views as well as tables to the all
        # tables list if view_support is True
        self._all_tables = set(
            self._inspector.get_table_names(schema=schema)
            + (self._inspector.get_view_names(schema=schema) if view_support else [])
        )

        self._include_tables = set(include_tables) if include_tables else set()
        if self._include_tables:
            missing_tables = self._include_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"include_tables {missing_tables} not found in database"
                )

        self._ignore_tables = set(ignore_tables) if ignore_tables else set()
        if self._ignore_tables:
            missing_tables = self._ignore_tables - self._all_tables
            if missing_tables:
                raise ValueError(
                    f"ignore_tables {missing_tables} not found in database"
                )

        # NOTE(review): when the usable list is empty (e.g. every table is
        # ignored) this falls back to *all* tables - confirm that is intended.
        usable_tables = self.get_usable_table_names()
        self._usable_tables = set(usable_tables) if usable_tables else self._all_tables

        if not isinstance(sample_rows_in_table_info, int):
            raise TypeError("sample_rows_in_table_info must be an integer")

        self._sample_rows_in_table_info = sample_rows_in_table_info
        self._indexes_in_table_info = indexes_in_table_info

        self._custom_table_info = custom_table_info
        if self._custom_table_info:
            if not isinstance(self._custom_table_info, dict):
                raise TypeError(
                    "table_info must be a dictionary with table names as keys and the "
                    "desired table info as values"
                )
            # only keep the tables that are also present in the database
            self._custom_table_info = {
                table: info
                for table, info in self._custom_table_info.items()
                if table in self._all_tables
            }

        self._max_string_length = max_string_length

        self._metadata = metadata or MetaData()
        # including view support if view_support = true
        self._metadata.reflect(
            views=view_support,
            bind=self._engine,
            only=list(self._usable_tables),
            schema=self._schema,
        )

    @property
    def engine(self) -> Engine:
        """Return SQL Alchemy engine."""
        return self._engine

    @property
    def metadata_obj(self) -> MetaData:
        """Return SQL Alchemy metadata."""
        return self._metadata

    @classmethod
    def from_uri(
        cls, database_uri: str, engine_args: Optional[dict] = None, **kwargs: Any
    ) -> "SQLDatabase":
        """Construct a SQLDatabase from a database URI.

        An engine is created with ``create_engine(database_uri, **engine_args)``;
        the remaining keyword arguments are forwarded to the constructor.
        """
        _engine_args = engine_args or {}
        return cls(create_engine(database_uri, **_engine_args), **kwargs)

    @property
    def dialect(self) -> str:
        """Return string representation of dialect to use."""
        return self._engine.dialect.name

    def get_usable_table_names(self) -> Iterable[str]:
        """Get names of tables available.

        Returns the sorted ``include_tables`` when that set is non-empty,
        otherwise every known table minus ``ignore_tables``, sorted.
        """
        if self._include_tables:
            return sorted(self._include_tables)
        return sorted(self._all_tables - self._ignore_tables)

    def get_table_columns(self, table_name: str) -> List[Any]:
        """Get table columns, looked up in the configured schema (if any)."""
        # Pass the schema so this agrees with get_single_table_info; with no
        # schema configured this is the inspector's default lookup.
        return self._inspector.get_columns(table_name, schema=self._schema)

    def get_single_table_info(self, table_name: str) -> str:
        """Get table info for a single table.

        Returns a one-line description listing each column as ``name (type)``
        (followed by its comment when one is present) and each foreign key as
        ``constrained_columns -> referred_table.referred_columns``.
        """
        # same logic as table_info, but with specific table names
        template = (
            "Table '{table_name}' has columns: {columns}, "
            "and foreign keys: {foreign_keys}."
        )

        columns = [
            (
                f"{column['name']} ({column['type']!s}): "
                f"'{column.get('comment')}'"
            )
            if column.get("comment")
            else f"{column['name']} ({column['type']!s})"
            for column in self._inspector.get_columns(
                table_name, schema=self._schema
            )
        ]
        column_str = ", ".join(columns)

        foreign_keys = [
            f"{foreign_key['constrained_columns']} -> "
            f"{foreign_key['referred_table']}.{foreign_key['referred_columns']}"
            for foreign_key in self._inspector.get_foreign_keys(
                table_name, schema=self._schema
            )
        ]
        foreign_key_str = ", ".join(foreign_keys)

        return template.format(
            table_name=table_name, columns=column_str, foreign_keys=foreign_key_str
        )

    def insert_into_table(self, table_name: str, data: dict) -> None:
        """Insert one row into a table.

        Args:
            table_name: Name of a reflected table. When a schema is configured
                the bare table name is accepted as well as the qualified one.
            data: Mapping of column name to value for the new row.

        Raises:
            KeyError: If the table is not present in the reflected metadata.
        """
        tables = self._metadata.tables
        # Tables reflected with a schema are keyed as "schema.table", so try
        # that key first and fall back to the name exactly as given.
        qualified = f"{self._schema}.{table_name}" if self._schema else table_name
        table = tables[qualified] if qualified in tables else tables[table_name]
        stmt = insert(table).values(**data)
        with self._engine.begin() as connection:
            connection.execute(stmt)

    def truncate_word(self, content: Any, *, length: int, suffix: str = "...") -> str:
        """Truncate a string to at most ``length`` characters on a word boundary.

        Non-string content, and any content when ``length <= 0``, is returned
        unchanged. Strings that already fit are returned as is. Otherwise the
        text is cut so that the result plus ``suffix`` fits within ``length``,
        then trimmed back to the last space (if any) before ``suffix`` is added.
        """
        if not isinstance(content, str) or length <= 0:
            return content

        if len(content) <= length:
            return content

        if length < len(suffix):
            # No room for the suffix: a negative slice bound would otherwise
            # yield a result longer than ``length``, so hard-cut instead.
            return content[:length]

        return content[: length - len(suffix)].rsplit(" ", 1)[0] + suffix

    def run_sql(self, command: str) -> Tuple[str, Dict]:
        """Execute a SQL statement and return its results.

        Returns:
            A ``(text, payload)`` tuple. If the statement returns rows, ``text``
            is the string form of the truncated rows and ``payload`` holds them
            under ``"result"`` with the column names under ``"col_keys"``.
            If the statement returns no rows, ``("", {})`` is returned.

        Raises:
            NotImplementedError: If the database rejects the statement with a
                ProgrammingError or OperationalError.
        """
        with self._engine.begin() as connection:
            try:
                if self._schema:
                    # NOTE(review): plain textual rewrite - it also touches
                    # "FROM " inside string literals, subqueries and names that
                    # are already schema-qualified, and misses lowercase "from".
                    command = command.replace("FROM ", f"FROM {self._schema}.")
                cursor = connection.execute(text(command))
            except (ProgrammingError, OperationalError) as exc:
                raise NotImplementedError(
                    f"Statement {command!r} is invalid SQL."
                ) from exc
            if cursor.returns_rows:
                result = cursor.fetchall()
                # truncate the results to the max string length
                # we can't use str(result) directly because it automatically
                # truncates long strings
                truncated_results = [
                    tuple(
                        self.truncate_word(column, length=self._max_string_length)
                        for column in row
                    )
                    for row in result
                ]
                return str(truncated_results), {
                    "result": truncated_results,
                    "col_keys": list(cursor.keys()),
                }
        return "", {}