From a56487a2242332e94afde8d74ed8b6f37cded0ac Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Thu, 14 Mar 2024 22:48:20 +0530 Subject: [PATCH 01/10] build: add mysqlclient to poetry --- poetry.lock | 22 ++++++++++++++++++++-- pyproject.toml | 1 + 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index e160ac765f5..b94204328f4 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. [[package]] name = "aiohttp" @@ -3289,6 +3289,24 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "mysqlclient" +version = "2.2.4" +description = "Python interface to MySQL" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mysqlclient-2.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:ac44777eab0a66c14cb0d38965572f762e193ec2e5c0723bcd11319cc5b693c5"}, + {file = "mysqlclient-2.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:329e4eec086a2336fe3541f1ce095d87a6f169d1cc8ba7b04ac68bcb234c9711"}, + {file = "mysqlclient-2.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:e1ebe3f41d152d7cb7c265349fdb7f1eca86ccb0ca24a90036cde48e00ceb2ab"}, + {file = "mysqlclient-2.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:3c318755e06df599338dad7625f884b8a71fcf322a9939ef78c9b3db93e1de7a"}, + {file = "mysqlclient-2.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:9d4c015480c4a6b2b1602eccd9846103fc70606244788d04aa14b31c4bd1f0e2"}, + {file = "mysqlclient-2.2.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d43987bb9626096a302ca6ddcdd81feaeca65ced1d5fe892a6a66b808326aa54"}, + {file = "mysqlclient-2.2.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4e80dcad884dd6e14949ac6daf769123223a52a6805345608bf49cdaf7bc8b3a"}, + {file = "mysqlclient-2.2.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9d3310295cb682232cadc28abd172f406c718b9ada41d2371259098ae37779d3"}, + {file = "mysqlclient-2.2.4.tar.gz", hash = "sha256:33bc9fb3464e7d7c10b1eaf7336c5ff8f2a3d3b88bab432116ad2490beb3bf41"}, +] + [[package]] name = "ncclient" version = "0.6.15" @@ -7838,4 +7856,4 @@ networking = ["junos-eznc"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "76a93f37924c0df19297e96c43dcd8917ad25f04ff1194e1a00a58982e0871c9" +content-hash = "0defe7f936ad04209532f761048b763c18f8f1b83bcf2da1b561c6cc37ae449b" diff --git a/pyproject.toml b/pyproject.toml index 5cb280fdf5d..2aa4112ae36 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,6 +128,7 @@ croniter = "^2.0.1" drf-yaml = "^3.0.1" google-generativeai = "^0.3.1" scrapy-playwright = "^0.0.33" +mysqlclient = "^2.2.4" [tool.poetry.extras] networking = ["junos-eznc"] From 0397c3610d14f5476ddc784f6215b12b5eab454c Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Thu, 14 Mar 2024 22:48:38 +0530 Subject: [PATCH 02/10] chore: update vscode settings --- .vscode/settings.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 08bb3f6cf76..15a08cbb35e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,7 +4,7 @@ "editor.tabSize": 4, "editor.defaultFormatter": "ms-python.black-formatter", "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" } }, "isort.args": ["--profile", "black"], From 0a8ed149858848f4ad3881983579a9126cbfe52b Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Thu, 14 Mar 2024 22:49:08 +0530 Subject: [PATCH 03/10] fix: remove redundant url entry --- llmstack/datasources/urls.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llmstack/datasources/urls.py b/llmstack/datasources/urls.py index fc42285edc3..3c4cabbc4c9 100644 --- a/llmstack/datasources/urls.py +++ b/llmstack/datasources/urls.py @@ -8,10 +8,6 @@ "api/datasource_types", apis.DataSourceTypeViewSet.as_view({"get": "get"}), ), - path( - "api/datasource_types", - apis.DataSourceTypeViewSet.as_view({"get": "get"}), - ), # Data sources path( "api/datasources", From c80ebb2acc080ff940bc4469b3f964749ba49a7c Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Thu, 14 Mar 2024 22:50:31 +0530 Subject: [PATCH 04/10] feat: add support for sqlalchemy based data stores --- .../read.py => database/database_reader.py} | 88 +++++++----------- .../blocks/data/store/database/mysql.py | 85 +++++++++++++++++ .../__init__.py => database/postgresql.py} | 10 +- .../__init__.py => database/sqlite.py} | 3 +- .../blocks/data/store/database/utils.py | 83 +++++++++++++++++ .../common/blocks/data/store/sqlite/read.py | 77 --------------- .../store/{postgres => database}/__init__.py | 0 .../data/store/{sqlite => database}/sample.db | Bin .../store/database/test_database_reader.py | 65 +++++++++++++ .../blocks/data/store/sqlite/__init__.py | 0 .../blocks/data/store/sqlite/test_read.py | 21 ----- 11 files changed, 274 insertions(+), 158 deletions(-) rename llmstack/common/blocks/data/store/{postgres/read.py => database/database_reader.py} (55%) create mode 100644 llmstack/common/blocks/data/store/database/mysql.py rename llmstack/common/blocks/data/store/{postgres/__init__.py => database/postgresql.py} (91%) rename llmstack/common/blocks/data/store/{sqlite/__init__.py => database/sqlite.py} (76%) create mode 100644 llmstack/common/blocks/data/store/database/utils.py delete mode 100644 llmstack/common/blocks/data/store/sqlite/read.py rename llmstack/common/tests/blocks/data/store/{postgres => database}/__init__.py (100%) rename llmstack/common/tests/blocks/data/store/{sqlite => database}/sample.db (100%) create mode 100644 llmstack/common/tests/blocks/data/store/database/test_database_reader.py delete mode 100644 llmstack/common/tests/blocks/data/store/sqlite/__init__.py delete mode 100644 llmstack/common/tests/blocks/data/store/sqlite/test_read.py diff --git a/llmstack/common/blocks/data/store/postgres/read.py b/llmstack/common/blocks/data/store/database/database_reader.py similarity index 55% rename from llmstack/common/blocks/data/store/postgres/read.py rename to llmstack/common/blocks/data/store/database/database_reader.py index 93e486db4da..2b3fd9946cd 100644 --- a/llmstack/common/blocks/data/store/postgres/read.py +++ b/llmstack/common/blocks/data/store/database/database_reader.py @@ -1,50 +1,22 @@ +import collections +import datetime import json -from collections import defaultdict -from datetime import datetime -from uuid import UUID +import uuid +import sqlalchemy from psycopg2.extras import Range from llmstack.common.blocks.base.processor import ProcessorInterface from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument -from llmstack.common.blocks.data.store.postgres import ( - PostgresConfiguration, - PostgresOutput, - get_pg_connection, +from llmstack.common.blocks.data.store.database.utils import ( + DatabaseConfiguration, + DatabaseOutput, + get_database_connection, ) -class PostgresReaderInput(BaseSchema): - sql: str - - -types_map = { - 20: "integer", - 21: "integer", - 23: "integer", - 700: "float", - 1700: "float", - 701: "float", - 16: "boolean", - 1082: "date", - 1182: "date", - 1114: "datetime", - 1184: "datetime", - 1115: "datetime", - 1185: "datetime", - 1014: "string", - 1015: "string", - 1008: "string", - 1009: "string", - 2951: "string", - 1043: "string", - 1002: "string", - 1003: "string", -} - - -class PostgreSQLJSONEncoder(json.JSONEncoder): +class DatabaseJSONEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, Range): # From: https://github.com/psycopg/psycopg2/pull/779 @@ -64,20 +36,24 @@ def default(self, o): ] return "".join(items) - elif isinstance(o, UUID): + elif isinstance(o, uuid.UUID): return str(o.hex) - elif isinstance(o, datetime): + elif isinstance(o, (datetime.date, datetime.datetime)): return o.isoformat() - return super(PostgreSQLJSONEncoder, self).default(o) + return super().default(o) -class PostgresReader( - ProcessorInterface[PostgresReaderInput, PostgresOutput, PostgresConfiguration], +class DatabaseReaderInput(BaseSchema): + sql: str + + +class DatabaseReader( + ProcessorInterface[DatabaseReaderInput, DatabaseOutput, DatabaseConfiguration], ): def fetch_columns(self, columns): column_names = set() - duplicates_counters = defaultdict(int) + duplicates_counters = collections.defaultdict(int) new_columns = [] for col in columns: @@ -90,35 +66,35 @@ def fetch_columns(self, columns): ) column_names.add(column_name) - new_columns.append( - {"name": column_name, "friendly_name": column_name, "type": col[1]}, - ) + new_columns.append({"name": column_name, "type": col[1]}) return new_columns def process( self, - input: PostgresReaderInput, - configuration: PostgresConfiguration, - ) -> PostgresOutput: - connection = get_pg_connection(configuration.dict()) - cursor = connection.cursor() + input: DatabaseReaderInput, + configuration: DatabaseConfiguration, + ) -> DatabaseOutput: + connection = get_database_connection(configuration=configuration) try: - cursor.execute(input.sql) + result = connection.execute(sqlalchemy.text(input.sql)) + cursor = result.cursor + if cursor.description is not None: columns = self.fetch_columns( - [(i[0], types_map.get(i[1], None)) for i in cursor.description], + [(i[0], None) for i in cursor.description], ) rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor] data = {"columns": columns, "rows": rows} - json_data = json.dumps(data, cls=PostgreSQLJSONEncoder) + json_data = json.dumps(data, cls=DatabaseJSONEncoder) else: raise Exception("Query completed but it returned no data.") except Exception as e: - connection.cancel() + connection.close() + connection.engine.dispose() raise e - return PostgresOutput( + return DatabaseOutput( documents=[ DataDocument( content=json_data, diff --git a/llmstack/common/blocks/data/store/database/mysql.py b/llmstack/common/blocks/data/store/database/mysql.py new file mode 100644 index 00000000000..b3f4cd7be31 --- /dev/null +++ b/llmstack/common/blocks/data/store/database/mysql.py @@ -0,0 +1,85 @@ +from enum import Enum +from typing import ClassVar, List, Optional + +from llmstack.common.blocks.base.schema import BaseSchema +from llmstack.common.blocks.data import DataDocument + +try: + import MySQLdb + + enabled = True +except ImportError: + enabled = False + + +class SSLMode(str, Enum): + disabled = "DISABLED" + preferred = "PREFERRED" + required = "REQUIRED" + verify_ca = "VERIFY_CA" + verify_identity = "VERIFY_IDENTITY" + + +class MySQLConfiguration(BaseSchema): + engine: ClassVar[str] = "mysql" + user: Optional[str] + password: Optional[str] + host: str = "127.0.0.1" + port: int = 3306 + dbname: str + use_ssl: bool = False + sslmode: SSLMode = "preferred" + ssl_ca: Optional[str] = None + ssl_cert: Optional[str] = None + ssl_key: Optional[str] = None + + class Config: + schema_extra = { + "order": ["host", "port", "user", "password"], + "required": ["dbname"], + "secret": ["password", "ssl_ca", "ssl_cert", "ssl_key"], + "extra_options": ["sslmode", "ssl_ca", "ssl_cert", "ssl_key"], + } + + +class MySQLOutput(BaseSchema): + documents: List[DataDocument] + + +def get_mysql_ssl_config(configuration: dict): + if not configuration.get("use_ssl"): + return {} + + ssl_config = {"sslmode": configuration.get("sslmode", "prefer")} + + if configuration.get("use_ssl"): + config_map = {"ssl_mode": "preferred", "ssl_cacert": "ca", "ssl_cert": "cert", "ssl_key": "key"} + for key, cfg in config_map.items(): + val = configuration.get(key) + if val: + ssl_config[cfg] = val + + return ssl_config + + +def get_mysql_connection(configuration: dict): + params = dict( + host=configuration.get("host"), + user=configuration.get("user"), + passwd=configuration.get("password"), + db=configuration.get("dbname"), + port=configuration.get("port", 3306), + charset=configuration.get("charset", "utf8"), + use_unicode=configuration.get("use_unicode", True), + connect_timeout=configuration.get("connect_timeout", 60), + autocommit=configuration.get("autocommit", True), + ) + + ssl_options = get_mysql_ssl_config() + + if ssl_options: + params["ssl"] = ssl_options + + connection = MySQLdb.connect(**params) + + return connection diff --git a/llmstack/common/blocks/data/store/postgres/__init__.py b/llmstack/common/blocks/data/store/database/postgresql.py similarity index 91% rename from llmstack/common/blocks/data/store/postgres/__init__.py rename to llmstack/common/blocks/data/store/database/postgresql.py index d8a322169d9..93988e65340 100644 --- a/llmstack/common/blocks/data/store/postgres/__init__.py +++ b/llmstack/common/blocks/data/store/database/postgresql.py @@ -1,7 +1,7 @@ from base64 import b64decode from enum import Enum from tempfile import NamedTemporaryFile -from typing import List, Optional +from typing import ClassVar, List, Optional import psycopg2 @@ -19,6 +19,7 @@ class SSLMode(str, Enum): class PostgresConfiguration(BaseSchema): + engine: ClassVar[str] = "postgresql" user: Optional[str] password: Optional[str] host: str = "127.0.0.1" @@ -53,7 +54,10 @@ def _create_cert_file(configuration, key, ssl_config): ssl_config[key] = cert_file.name -def _get_ssl_config(configuration: dict): +def get_pg_ssl_config(configuration: dict): + if not configuration.get("use_ssl"): + return {} + ssl_config = {"sslmode": configuration.get("sslmode", "prefer")} _create_cert_file(configuration, "sslrootcert", ssl_config) @@ -65,7 +69,7 @@ def _get_ssl_config(configuration: dict): def get_pg_connection(configuration: dict): ssl_config = ( - _get_ssl_config( + get_pg_ssl_config( configuration, ) if configuration.get("use_ssl") diff --git a/llmstack/common/blocks/data/store/sqlite/__init__.py b/llmstack/common/blocks/data/store/database/sqlite.py similarity index 76% rename from llmstack/common/blocks/data/store/sqlite/__init__.py rename to llmstack/common/blocks/data/store/database/sqlite.py index 324fa2ea71e..c190c9a0b55 100644 --- a/llmstack/common/blocks/data/store/sqlite/__init__.py +++ b/llmstack/common/blocks/data/store/database/sqlite.py @@ -1,10 +1,11 @@ -from typing import List +from typing import ClassVar, List from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument class SQLiteConfiguration(BaseSchema): + engine: ClassVar[str] = "sqlite" dbpath: str diff --git a/llmstack/common/blocks/data/store/database/utils.py b/llmstack/common/blocks/data/store/database/utils.py new file mode 100644 index 00000000000..c631a2d8df0 --- /dev/null +++ b/llmstack/common/blocks/data/store/database/utils.py @@ -0,0 +1,83 @@ +from enum import StrEnum +from typing import List + +import sqlalchemy + +from llmstack.common.blocks.base.schema import BaseSchema +from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.mysql import ( + MySQLConfiguration, + get_mysql_ssl_config, +) +from llmstack.common.blocks.data.store.database.postgresql import ( + PostgresConfiguration, + get_pg_ssl_config, +) +from llmstack.common.blocks.data.store.database.sqlite import SQLiteConfiguration + + +class DatabaseEngineType(StrEnum): + POSTGRESQL = "postgresql" + MYSQL = "mysql" + SQLITE = "sqlite" + + +DATABASES = { + DatabaseEngineType.POSTGRESQL: { + "name": "PostgreSQL", + "driver": "postgresql+psycopg2", + }, + DatabaseEngineType.MYSQL: { + "name": "MySQL", + "driver": "mysql+mysqldb", + }, + DatabaseEngineType.SQLITE: { + "name": "SQLite", + "driver": "sqlite+pysqlite", + }, +} + +DatabaseConfiguration = MySQLConfiguration | PostgresConfiguration | SQLiteConfiguration + + +class DatabaseOutput(BaseSchema): + documents: List[DataDocument] + + +def get_database_connection( + configuration: DatabaseConfiguration, + ssl_config: dict = None, +) -> sqlalchemy.engine.Connection: + if configuration.engine not in DATABASES: + raise ValueError(f"Unsupported database engine: {configuration.type}") + + if not ssl_config: + if configuration.engine == DatabaseEngineType.POSTGRESQL: + ssl_config = get_pg_ssl_config(configuration.dict()) + elif configuration.engine == DatabaseEngineType.MYSQL: + ssl_config = get_mysql_ssl_config(configuration.dict()) + + database_name = configuration.dbpath if configuration.engine == DatabaseEngineType.SQLITE else configuration.dbname + + connect_args: dict = {} + + if ssl_config: + connect_args["ssl"] = ssl_config + + # Create URL + db_url = sqlalchemy.engine.URL.create( + drivername=DATABASES[configuration.engine]["driver"], + username=configuration.user if hasattr(configuration, "user") else None, + password=configuration.password if hasattr(configuration, "password") else None, + host=configuration.host if hasattr(configuration, "host") else None, + port=configuration.port if hasattr(configuration, "port") else None, + database=database_name, + ) + + # Create engine + engine = sqlalchemy.create_engine(db_url, connect_args=connect_args) + + # Connect to the database + connection = engine.connect() + + return connection diff --git a/llmstack/common/blocks/data/store/sqlite/read.py b/llmstack/common/blocks/data/store/sqlite/read.py deleted file mode 100644 index d774a04d0df..00000000000 --- a/llmstack/common/blocks/data/store/sqlite/read.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -import sqlite3 -from collections import defaultdict - -from llmstack.common.blocks.base.processor import ProcessorInterface -from llmstack.common.blocks.base.schema import BaseSchema -from llmstack.common.blocks.data import DataDocument -from llmstack.common.blocks.data.store.sqlite import SQLiteConfiguration, SQLiteOutput - - -class SQLiteReaderInput(BaseSchema): - sql: str - - -class SQLiteReader( - ProcessorInterface[SQLiteReaderInput, SQLiteOutput, SQLiteConfiguration], -): - def fetch_columns(self, columns): - column_names = set() - duplicates_counters = defaultdict(int) - new_columns = [] - - for col in columns: - column_name = col[0] - while column_name in column_names: - duplicates_counters[col[0]] += 1 - column_name = "{}{}".format( - col[0], - duplicates_counters[col[0]], - ) - - column_names.add(column_name) - new_columns.append( - {"name": column_name, "friendly_name": column_name, "type": col[1]}, - ) - - return new_columns - - def process( - self, - input: SQLiteReaderInput, - configuration: SQLiteConfiguration, - ) -> SQLiteOutput: - connection = None - try: - connection = sqlite3.connect(configuration.dbpath) - cursor = connection.cursor() - cursor.execute(input.sql) - - if cursor.description is not None: - columns = self.fetch_columns( - [(i[0], None) for i in cursor.description], - ) - rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor] - - data = {"columns": columns, "rows": rows} - json_data = json.dumps(data) - else: - raise Exception("Query completed but it returned no data.") - except Exception as e: - if connection: - connection.cancel() - raise e - finally: - if connection: - connection.close() - return SQLiteOutput( - documents=[ - DataDocument( - content=json_data, - content_text=json_data, - metadata={ - "mime_type": "application/json", - }, - ), - ], - ) diff --git a/llmstack/common/tests/blocks/data/store/postgres/__init__.py b/llmstack/common/tests/blocks/data/store/database/__init__.py similarity index 100% rename from llmstack/common/tests/blocks/data/store/postgres/__init__.py rename to llmstack/common/tests/blocks/data/store/database/__init__.py diff --git a/llmstack/common/tests/blocks/data/store/sqlite/sample.db b/llmstack/common/tests/blocks/data/store/database/sample.db similarity index 100% rename from llmstack/common/tests/blocks/data/store/sqlite/sample.db rename to llmstack/common/tests/blocks/data/store/database/sample.db diff --git a/llmstack/common/tests/blocks/data/store/database/test_database_reader.py b/llmstack/common/tests/blocks/data/store/database/test_database_reader.py new file mode 100644 index 00000000000..d6d3004a5ec --- /dev/null +++ b/llmstack/common/tests/blocks/data/store/database/test_database_reader.py @@ -0,0 +1,65 @@ +import unittest + +from llmstack.common.blocks.data.store.database.database_reader import ( + DatabaseReader, + DatabaseReaderInput, +) +from llmstack.common.blocks.data.store.database.mysql import MySQLConfiguration +from llmstack.common.blocks.data.store.database.postgresql import PostgresConfiguration +from llmstack.common.blocks.data.store.database.sqlite import SQLiteConfiguration + + +class MySQLReadTest(unittest.TestCase): + def test_read(self): + configuration = MySQLConfiguration( + user="root", + password="", + host="localhost", + port=5432, + dbname="usersdb", + ) + reader_input = DatabaseReaderInput( + sql="SELECT * FROM users", + ) + + response = DatabaseReader().process( + reader_input, + configuration, + ) + + self.assertEqual(len(response.documents), 1) + + +class PostgresReadTest(unittest.TestCase): + def test_read(self): + configuration = PostgresConfiguration( + user="root", + password="", + host="localhost", + port=5432, + dbname="usersdb", + ) + reader_input = DatabaseReaderInput( + sql="SELECT * FROM users", + ) + + response = DatabaseReader().process( + reader_input, + configuration, + ) + + self.assertEqual(len(response.documents), 1) + + +class SqliteReadTest(unittest.TestCase): + def test_read(self): + sample_db = f"{'/'.join((__file__.split('/')[:-1]))}/sample.db" + response = DatabaseReader().process( + DatabaseReaderInput( + sql="SELECT * FROM users", + ), + SQLiteConfiguration( + dbpath=sample_db, + ), + ) + self.assertEqual(len(response.documents), 1) diff --git a/llmstack/common/tests/blocks/data/store/sqlite/__init__.py b/llmstack/common/tests/blocks/data/store/sqlite/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/llmstack/common/tests/blocks/data/store/sqlite/test_read.py b/llmstack/common/tests/blocks/data/store/sqlite/test_read.py deleted file mode 100644 index f8929bf4ba1..00000000000 --- a/llmstack/common/tests/blocks/data/store/sqlite/test_read.py +++ /dev/null @@ -1,21 +0,0 @@ -import unittest - -from llmstack.common.blocks.data.store.sqlite import SQLiteConfiguration -from llmstack.common.blocks.data.store.sqlite.read import ( - SQLiteReader, - SQLiteReaderInput, -) - - -class SqliteReadTest(unittest.TestCase): - def test_read(self): - sample_db = f"{'/'.join((__file__.split('/')[:-1]))}/sample.db" - response = SQLiteReader().process( - SQLiteReaderInput( - sql="SELECT * FROM users", - ), - SQLiteConfiguration( - dbpath=sample_db, - ), - ) - self.assertEquals(len(response.documents), 1) From ad47adacde966ec9d4d85f9c38a5741b979b65d7 Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Thu, 14 Mar 2024 22:51:23 +0530 Subject: [PATCH 05/10] feat: add SQLDataSource handler to connect with sqlalchemy provided data sources or engines --- .../databases/{postgres.py => sql.py} | 85 +++++++++++-------- 1 file changed, 49 insertions(+), 36 deletions(-) rename llmstack/datasources/handlers/databases/{postgres.py => sql.py} (69%) diff --git a/llmstack/datasources/handlers/databases/postgres.py b/llmstack/datasources/handlers/databases/sql.py similarity index 69% rename from llmstack/datasources/handlers/databases/postgres.py rename to llmstack/datasources/handlers/databases/sql.py index 8e09ef31f2b..2815b25e3c5 100644 --- a/llmstack/datasources/handlers/databases/postgres.py +++ b/llmstack/datasources/handlers/databases/sql.py @@ -5,10 +5,14 @@ from pydantic import Field from llmstack.common.blocks.base.schema import BaseSchema as _Schema -from llmstack.common.blocks.data.store.postgres import PostgresConfiguration -from llmstack.common.blocks.data.store.postgres.read import ( - PostgresReader, - PostgresReaderInput, +from llmstack.common.blocks.data.store.database.database_reader import ( + DatabaseReader, + DatabaseReaderInput, +) +from llmstack.common.blocks.data.store.database.utils import ( + DATABASES, + DatabaseConfiguration, + DatabaseEngineType, ) from llmstack.common.blocks.data.store.vectorstore import Document from llmstack.common.utils.models import Config @@ -22,43 +26,54 @@ logger = logging.getLogger(__name__) -class PostgresConnection(_Schema): - host: str = Field(description="Host of the Postgres instance") +class SQLConnection(_Schema): + host: str = Field(description="Host of the Database instance") port: int = Field( - description="Port number to connect to the Postgres instance", + description="Port number to connect to the Database instance", ) - database_name: str = Field(description="Postgres database name") - username: str = Field(description="Postgres username") - password: Optional[str] = Field(description="Postgres password") + database_name: str = Field(description="Database name") + username: str = Field(description="Database username") + password: Optional[str] = Field(description="Database password") -class PostgresDatabaseSchema(DataSourceSchema): - connection: Optional[PostgresConnection] = Field( - description="Postgres connection details", +class SQLDatabaseSchema(DataSourceSchema): + connection: Optional[SQLConnection] = Field( + description="Database connection details", ) -class PostgresConnectionConfiguration(Config): - config_type = "postgres_connection" +class SQLConnectionConfiguration(Config): + engine: Optional[DatabaseEngineType] = None + config_type: Optional[str] = None is_encrypted = True - postgres_config: Optional[Dict] + config: Optional[Dict] + + def __init__(self, engine: DatabaseEngineType, *args, **kwargs): + super().__init__(**args, **kwargs) + self.engine = engine + self.config_type = f"{engine}_connection" -class PostgresDataSource(DataSourceProcessor[PostgresDatabaseSchema]): +class SQLDataSource(DataSourceProcessor[SQLDatabaseSchema]): # Initializer for the class. # It requires a datasource object as input, checks if it has a 'data' # configuration, and sets up Weaviate Database Configuration. def __init__(self, datasource: DataSource): self.datasource = datasource + + if self.datasource.type.slug not in DATABASES: + raise ValueError(f"Database engine {self.datasource.type.slug} not supported") + if self.datasource.config and "data" in self.datasource.config: - config_dict = PostgresConnectionConfiguration().from_dict( + config_dict = SQLConnectionConfiguration(engine=self.datasource.type.slug).from_dict( self.datasource.config, self.datasource.profile.decrypt_value, ) - self._configuration = PostgresDatabaseSchema( - **config_dict["postgres_config"], + self._configuration = SQLDatabaseSchema( + **config_dict["config"], ) - self._reader_configuration = PostgresConfiguration( + self._reader_configuration = DatabaseConfiguration( + engine=self.database_engine, user=self._configuration.connection.username, password=self._configuration.connection.password, host=self._configuration.connection.host, @@ -70,35 +85,33 @@ def __init__(self, datasource: DataSource): @staticmethod def name() -> str: - return "Postgres" + return "SQL" @staticmethod def slug() -> str: - return "postgres" + return "sql" @staticmethod def description() -> str: - return "Connect to a Postgres database" + return "Connect to a SQL Database" # This static method takes a dictionary for configuration and a DataSource object as inputs. # Validation of these inputs is performed and a dictionary containing the - # Postgres Connection Configuration is returned. + # Database Connection Configuration is returned. @staticmethod def process_validate_config( config_data: dict, datasource: DataSource, ) -> dict: - return PostgresConnectionConfiguration( - postgres_config=config_data, + return SQLConnectionConfiguration( + config=config_data, ).to_dict( encrypt_fn=datasource.profile.encrypt_value, ) - # This static method returns the provider slug for the datasource - # connector. @staticmethod def provider_slug() -> str: - return "postgres" + return "promptly" def validate_and_process(self, data: dict) -> List[DataSourceEntryItem]: raise NotImplementedError @@ -110,10 +123,10 @@ def add_entry(self, data: dict) -> Optional[DataSourceEntryItem]: raise NotImplementedError def similarity_search(self, query: str, **kwargs) -> List[dict]: - pg_client = PostgresReader() + client = DatabaseReader() result = ( - pg_client.process( - PostgresReaderInput( + client.process( + DatabaseReaderInput( sql=query, ), configuration=self._reader_configuration, @@ -148,10 +161,10 @@ def similarity_search(self, query: str, **kwargs) -> List[dict]: ] def hybrid_search(self, query: str, **kwargs) -> List[dict]: - pg_client = PostgresReader() + client = DatabaseReader() result = ( - pg_client.process( - PostgresReaderInput( + client.process( + DatabaseReaderInput( sql=query, ), configuration=self._reader_configuration, From 6a0a8419d17e192e546f51222508b656e08e4bd0 Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Fri, 15 Mar 2024 08:01:39 +0530 Subject: [PATCH 06/10] feat: make SQLConnection an union --- .../datasources/handlers/databases/sql.py | 55 +++++++++++++------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/llmstack/datasources/handlers/databases/sql.py b/llmstack/datasources/handlers/databases/sql.py index 2815b25e3c5..831f289a90a 100644 --- a/llmstack/datasources/handlers/databases/sql.py +++ b/llmstack/datasources/handlers/databases/sql.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from pydantic import Field @@ -26,14 +26,31 @@ logger = logging.getLogger(__name__) -class SQLConnection(_Schema): - host: str = Field(description="Host of the Database instance") +class PostgresConnection(_Schema): + host: str = Field(description="Host of the Postgres instance") port: int = Field( - description="Port number to connect to the Database instance", + description="Port number to connect to the Postgres instance", ) - database_name: str = Field(description="Database name") - username: str = Field(description="Database username") - password: Optional[str] = Field(description="Database password") + database_name: str = Field(description="Postgres database name") + username: str = Field(description="Postgres username") + password: Optional[str] = Field(description="Postgres password") + + +class MySQLConnection(_Schema): + host: str = Field(description="Host of the MySQL instance") + port: int = Field( + description="Port number to connect to the MySQL instance", + ) + database_name: str = Field(description="MySQL database name") + username: str = Field(description="MySQL username") + password: Optional[str] = Field(description="MySQL password") + + +class SQLiteConnection(_Schema): + database_path: str = Field(description="MySQL database name") + + +SQLConnection = Union[PostgresConnection, MySQLConnection, SQLiteConnection] class SQLDatabaseSchema(DataSourceSchema): @@ -72,15 +89,21 @@ def __init__(self, datasource: DataSource): self._configuration = SQLDatabaseSchema( **config_dict["config"], ) - self._reader_configuration = DatabaseConfiguration( - engine=self.database_engine, - user=self._configuration.connection.username, - password=self._configuration.connection.password, - host=self._configuration.connection.host, - port=self._configuration.connection.port, - dbname=self._configuration.connection.database_name, - use_ssl=False, - ) + if self.datasource.type.slug != DatabaseEngineType.SQLITE: + self._reader_configuration = DatabaseConfiguration( + engine=self.datasource.type.slug, + dbpath=self._configuration.connection.database_path, + ) + else: + self._reader_configuration = DatabaseConfiguration( + engine=self.datasource.type.slug, + user=self._configuration.connection.username, + password=self._configuration.connection.password, + host=self._configuration.connection.host, + port=self._configuration.connection.port, + dbname=self._configuration.connection.database_name, + use_ssl=False, + ) self._source_name = self.datasource.name @staticmethod From 0f315782460909211e9e08c56206b73dae7e53c8 Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Fri, 15 Mar 2024 08:04:42 +0530 Subject: [PATCH 07/10] refactor: change PostgresConnection to PostgreSQLConnection --- llmstack/datasources/handlers/databases/sql.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llmstack/datasources/handlers/databases/sql.py b/llmstack/datasources/handlers/databases/sql.py index 831f289a90a..de6b807f8f9 100644 --- a/llmstack/datasources/handlers/databases/sql.py +++ b/llmstack/datasources/handlers/databases/sql.py @@ -26,14 +26,14 @@ logger = logging.getLogger(__name__) -class PostgresConnection(_Schema): - host: str = Field(description="Host of the Postgres instance") +class PostgreSQLConnection(_Schema): + host: str = Field(description="Host of the PostgreSQL instance") port: int = Field( - description="Port number to connect to the Postgres instance", + description="Port number to connect to the PostgreSQL instance", ) - database_name: str = Field(description="Postgres database name") - username: str = Field(description="Postgres username") - password: Optional[str] = Field(description="Postgres password") + database_name: str = Field(description="PostgreSQL database name") + username: str = Field(description="PostgreSQL username") + password: Optional[str] = Field(description="PostgreSQL password") class MySQLConnection(_Schema): @@ -50,7 +50,7 @@ class SQLiteConnection(_Schema): database_path: str = Field(description="MySQL database name") -SQLConnection = Union[PostgresConnection, MySQLConnection, SQLiteConnection] +SQLConnection = Union[PostgreSQLConnection, MySQLConnection, SQLiteConnection] class SQLDatabaseSchema(DataSourceSchema): From 23ffb7df350017a989358002824608ec369470f8 Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Fri, 15 Mar 2024 17:25:40 +0530 Subject: [PATCH 08/10] feat: load DataSources on every request --- llmstack/datasources/apis.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/llmstack/datasources/apis.py b/llmstack/datasources/apis.py index 70b03983978..d01285acf57 100644 --- a/llmstack/datasources/apis.py +++ b/llmstack/datasources/apis.py @@ -33,16 +33,15 @@ class DataSourceTypeViewSet(viewsets.ModelViewSet): - queryset = DataSourceType.objects.all() serializer_class = DataSourceTypeSerializer + def get_queryset(self): + return DataSourceType.objects.all() + def get(self, request): - return DRFResponse( - DataSourceTypeSerializer( - instance=self.queryset, - many=True, - ).data, - ) + queryset = self.get_queryset() + serialzer = self.serializer_class(instance=queryset, many=True) + return DRFResponse(serialzer.data) class DataSourceEntryViewSet(viewsets.ModelViewSet): From 817b7545309273f58d1cb01f59804d1df8bde1e9 Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Fri, 15 Mar 2024 17:29:21 +0530 Subject: [PATCH 09/10] feat: allow UI to render SQL connections --- .../blocks/data/store/database/constants.py | 7 +++ .../data/store/database/database_reader.py | 3 +- .../blocks/data/store/database/mysql.py | 7 ++- .../blocks/data/store/database/postgresql.py | 6 +- .../blocks/data/store/database/sqlite.py | 7 ++- .../blocks/data/store/database/utils.py | 42 +++++++++----- .../datasources/handlers/databases/sql.py | 55 +++++++++++-------- 7 files changed, 81 insertions(+), 46 deletions(-) create mode 100644 llmstack/common/blocks/data/store/database/constants.py diff --git a/llmstack/common/blocks/data/store/database/constants.py b/llmstack/common/blocks/data/store/database/constants.py new file mode 100644 index 00000000000..985d3c57d60 --- /dev/null +++ b/llmstack/common/blocks/data/store/database/constants.py @@ -0,0 +1,7 @@ +from enum import StrEnum + + +class DatabaseEngineType(StrEnum): + POSTGRESQL = "postgresql" + MYSQL = "mysql" + SQLITE = "sqlite" diff --git a/llmstack/common/blocks/data/store/database/database_reader.py b/llmstack/common/blocks/data/store/database/database_reader.py index 2b3fd9946cd..4c278f61f1f 100644 --- a/llmstack/common/blocks/data/store/database/database_reader.py +++ b/llmstack/common/blocks/data/store/database/database_reader.py @@ -11,6 +11,7 @@ from llmstack.common.blocks.data import DataDocument from llmstack.common.blocks.data.store.database.utils import ( DatabaseConfiguration, + DatabaseConfigurationType, DatabaseOutput, get_database_connection, ) @@ -73,7 +74,7 @@ def fetch_columns(self, columns): def process( self, input: DatabaseReaderInput, - configuration: DatabaseConfiguration, + configuration: DatabaseConfigurationType, ) -> DatabaseOutput: connection = get_database_connection(configuration=configuration) try: diff --git a/llmstack/common/blocks/data/store/database/mysql.py b/llmstack/common/blocks/data/store/database/mysql.py index b3f4cd7be31..83c3bd602b8 100644 --- a/llmstack/common/blocks/data/store/database/mysql.py +++ b/llmstack/common/blocks/data/store/database/mysql.py @@ -1,8 +1,11 @@ from enum import Enum -from typing import ClassVar, List, Optional +from typing import List, Optional + +from typing_extensions import Literal from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType try: import MySQLdb @@ -21,7 +24,7 @@ class SSLMode(str, Enum): class MySQLConfiguration(BaseSchema): - engine: ClassVar[str] = "mysql" + engine: Literal[DatabaseEngineType.MYSQL] = DatabaseEngineType.MYSQL user: Optional[str] password: Optional[str] host: str = "127.0.0.1" diff --git a/llmstack/common/blocks/data/store/database/postgresql.py b/llmstack/common/blocks/data/store/database/postgresql.py index 93988e65340..30a39670be3 100644 --- a/llmstack/common/blocks/data/store/database/postgresql.py +++ b/llmstack/common/blocks/data/store/database/postgresql.py @@ -1,12 +1,14 @@ from base64 import b64decode from enum import Enum from tempfile import NamedTemporaryFile -from typing import ClassVar, List, Optional +from typing import List, Optional import psycopg2 +from typing_extensions import Literal from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType class SSLMode(str, Enum): @@ -19,7 +21,7 @@ class SSLMode(str, Enum): class PostgresConfiguration(BaseSchema): - engine: ClassVar[str] = "postgresql" + engine: Literal[DatabaseEngineType.POSTGRESQL] = DatabaseEngineType.POSTGRESQL user: Optional[str] password: Optional[str] host: str = "127.0.0.1" diff --git a/llmstack/common/blocks/data/store/database/sqlite.py b/llmstack/common/blocks/data/store/database/sqlite.py index c190c9a0b55..c4722760b50 100644 --- a/llmstack/common/blocks/data/store/database/sqlite.py +++ b/llmstack/common/blocks/data/store/database/sqlite.py @@ -1,11 +1,14 @@ -from typing import ClassVar, List +from typing import List + +from typing_extensions import Literal from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType class SQLiteConfiguration(BaseSchema): - engine: ClassVar[str] = "sqlite" + engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE dbpath: str diff --git a/llmstack/common/blocks/data/store/database/utils.py b/llmstack/common/blocks/data/store/database/utils.py index c631a2d8df0..8259598f503 100644 --- a/llmstack/common/blocks/data/store/database/utils.py +++ b/llmstack/common/blocks/data/store/database/utils.py @@ -1,10 +1,10 @@ -from enum import StrEnum -from typing import List +from typing import List, TypeVar import sqlalchemy from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType from llmstack.common.blocks.data.store.database.mysql import ( MySQLConfiguration, get_mysql_ssl_config, @@ -15,13 +15,6 @@ ) from llmstack.common.blocks.data.store.database.sqlite import SQLiteConfiguration - -class DatabaseEngineType(StrEnum): - POSTGRESQL = "postgresql" - MYSQL = "mysql" - SQLITE = "sqlite" - - DATABASES = { DatabaseEngineType.POSTGRESQL: { "name": "PostgreSQL", @@ -39,23 +32,42 @@ class DatabaseEngineType(StrEnum): DatabaseConfiguration = MySQLConfiguration | PostgresConfiguration | SQLiteConfiguration +DatabaseConfigurationType = TypeVar("DatabaseConfigurationType", bound=DatabaseConfiguration) + class DatabaseOutput(BaseSchema): documents: List[DataDocument] +def get_database_configuration_class(engine: DatabaseEngineType) -> DatabaseConfigurationType: + if engine == DatabaseEngineType.POSTGRESQL: + return PostgresConfiguration + elif engine == DatabaseEngineType.MYSQL: + return MySQLConfiguration + elif engine == DatabaseEngineType.SQLITE: + return SQLiteConfiguration + else: + raise ValueError(f"Unsupported database engine: {engine}") + + +def get_ssl_config(configuration: DatabaseConfigurationType) -> dict: + ssl_config = {} + if configuration.engine == DatabaseEngineType.POSTGRESQL: + ssl_config = get_pg_ssl_config(configuration.dict()) + elif configuration.engine == DatabaseEngineType.MYSQL: + ssl_config = get_mysql_ssl_config(configuration.dict()) + return ssl_config + + def get_database_connection( - configuration: DatabaseConfiguration, + configuration: DatabaseConfigurationType, ssl_config: dict = None, ) -> sqlalchemy.engine.Connection: if configuration.engine not in DATABASES: - raise ValueError(f"Unsupported database engine: {configuration.type}") + raise ValueError(f"Unsupported database engine: {configuration.engine}") if not ssl_config: - if configuration.engine == DatabaseEngineType.POSTGRESQL: - ssl_config = get_pg_ssl_config(configuration.dict()) - elif configuration.engine == DatabaseEngineType.MYSQL: - ssl_config = get_mysql_ssl_config(configuration.dict()) + ssl_config = get_ssl_config(configuration) database_name = configuration.dbpath if configuration.engine == DatabaseEngineType.SQLITE else configuration.dbname diff --git a/llmstack/datasources/handlers/databases/sql.py b/llmstack/datasources/handlers/databases/sql.py index de6b807f8f9..4f89036da60 100644 --- a/llmstack/datasources/handlers/databases/sql.py +++ b/llmstack/datasources/handlers/databases/sql.py @@ -1,6 +1,6 @@ import json import logging -from typing import Dict, List, Optional, Union +from typing import ClassVar, Dict, List, Optional, Union from pydantic import Field @@ -10,9 +10,8 @@ DatabaseReaderInput, ) from llmstack.common.blocks.data.store.database.utils import ( - DATABASES, - DatabaseConfiguration, DatabaseEngineType, + get_database_configuration_class, ) from llmstack.common.blocks.data.store.vectorstore import Document from llmstack.common.utils.models import Config @@ -27,49 +26,56 @@ class PostgreSQLConnection(_Schema): + engine: ClassVar[str] = DatabaseEngineType.POSTGRESQL host: str = Field(description="Host of the PostgreSQL instance") port: int = Field( description="Port number to connect to the PostgreSQL instance", ) database_name: str = Field(description="PostgreSQL database name") - username: str = Field(description="PostgreSQL username") - password: Optional[str] = Field(description="PostgreSQL password") + username: str = Field(description="PostgreSQL database username") + password: Optional[str] = Field(description="PostgreSQL database password") + + class Config: + title = "PostgreSQL" class MySQLConnection(_Schema): + engine: ClassVar[str] = DatabaseEngineType.MYSQL host: str = Field(description="Host of the MySQL instance") port: int = Field( description="Port number to connect to the MySQL instance", ) database_name: str = Field(description="MySQL database name") - username: str = Field(description="MySQL username") - password: Optional[str] = Field(description="MySQL password") + username: str = Field(description="MySQL database username") + password: Optional[str] = Field(description="MySQL database password") + + class Config: + title = "MySQL" class SQLiteConnection(_Schema): + engine: ClassVar[str] = DatabaseEngineType.SQLITE database_path: str = Field(description="MySQL database name") + class Config: + title = "SQLite" + SQLConnection = Union[PostgreSQLConnection, MySQLConnection, SQLiteConnection] class SQLDatabaseSchema(DataSourceSchema): connection: Optional[SQLConnection] = Field( - description="Database connection details", + title="Database", + description="Database details", ) class SQLConnectionConfiguration(Config): - engine: Optional[DatabaseEngineType] = None - config_type: Optional[str] = None + config_type: Optional[str] = "sql_connection" is_encrypted = True config: Optional[Dict] - def __init__(self, engine: DatabaseEngineType, *args, **kwargs): - super().__init__(**args, **kwargs) - self.engine = engine - self.config_type = f"{engine}_connection" - class SQLDataSource(DataSourceProcessor[SQLDatabaseSchema]): # Initializer for the class. @@ -78,25 +84,26 @@ class SQLDataSource(DataSourceProcessor[SQLDatabaseSchema]): def __init__(self, datasource: DataSource): self.datasource = datasource - if self.datasource.type.slug not in DATABASES: - raise ValueError(f"Database engine {self.datasource.type.slug} not supported") - if self.datasource.config and "data" in self.datasource.config: - config_dict = SQLConnectionConfiguration(engine=self.datasource.type.slug).from_dict( + config_dict = SQLConnectionConfiguration().from_dict( self.datasource.config, self.datasource.profile.decrypt_value, ) + self._configuration = SQLDatabaseSchema( **config_dict["config"], ) - if self.datasource.type.slug != DatabaseEngineType.SQLITE: - self._reader_configuration = DatabaseConfiguration( - engine=self.datasource.type.slug, + + database_configuration_class = get_database_configuration_class(self._configuration.connection.engine) + + if self._configuration.connection.engine == DatabaseEngineType.SQLITE: + self._reader_configuration = database_configuration_class( + engine=self._configuration.connection.engine, dbpath=self._configuration.connection.database_path, ) else: - self._reader_configuration = DatabaseConfiguration( - engine=self.datasource.type.slug, + self._reader_configuration = database_configuration_class( + engine=self._configuration.connection.engine, user=self._configuration.connection.username, password=self._configuration.connection.password, host=self._configuration.connection.host, From 32a29a64d2c00b810543610161cbb5e6c8961e0b Mon Sep 17 00:00:00 2001 From: Bala Subrahmanyam Varanasi Date: Fri, 15 Mar 2024 17:55:13 +0530 Subject: [PATCH 10/10] feat: change engine from classvar to literal in connection schemas --- llmstack/datasources/handlers/databases/sql.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/llmstack/datasources/handlers/databases/sql.py b/llmstack/datasources/handlers/databases/sql.py index 4f89036da60..812b418f640 100644 --- a/llmstack/datasources/handlers/databases/sql.py +++ b/llmstack/datasources/handlers/databases/sql.py @@ -1,8 +1,9 @@ import json import logging -from typing import ClassVar, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union from pydantic import Field +from typing_extensions import Literal from llmstack.common.blocks.base.schema import BaseSchema as _Schema from llmstack.common.blocks.data.store.database.database_reader import ( @@ -26,7 +27,7 @@ class PostgreSQLConnection(_Schema): - engine: ClassVar[str] = DatabaseEngineType.POSTGRESQL + engine: Literal[DatabaseEngineType.POSTGRESQL] = DatabaseEngineType.POSTGRESQL host: str = Field(description="Host of the PostgreSQL instance") port: int = Field( description="Port number to connect to the PostgreSQL instance", @@ -40,7 +41,7 @@ class Config: class MySQLConnection(_Schema): - engine: ClassVar[str] = DatabaseEngineType.MYSQL + engine: Literal[DatabaseEngineType.MYSQL] = DatabaseEngineType.MYSQL host: str = Field(description="Host of the MySQL instance") port: int = Field( description="Port number to connect to the MySQL instance", @@ -54,7 +55,7 @@ class Config: class SQLiteConnection(_Schema): - engine: ClassVar[str] = DatabaseEngineType.SQLITE + engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE database_path: str = Field(description="MySQL database name") class Config: