diff --git a/.vscode/settings.json b/.vscode/settings.json index 08bb3f6cf76..15a08cbb35e 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,7 +4,7 @@ "editor.tabSize": 4, "editor.defaultFormatter": "ms-python.black-formatter", "editor.codeActionsOnSave": { - "source.organizeImports": true + "source.organizeImports": "explicit" } }, "isort.args": ["--profile", "black"], diff --git a/llmstack/common/blocks/data/store/database/constants.py b/llmstack/common/blocks/data/store/database/constants.py new file mode 100644 index 00000000000..985d3c57d60 --- /dev/null +++ b/llmstack/common/blocks/data/store/database/constants.py @@ -0,0 +1,7 @@ +from enum import StrEnum + + +class DatabaseEngineType(StrEnum): + POSTGRESQL = "postgresql" + MYSQL = "mysql" + SQLITE = "sqlite" diff --git a/llmstack/common/blocks/data/store/postgres/read.py b/llmstack/common/blocks/data/store/database/database_reader.py similarity index 55% rename from llmstack/common/blocks/data/store/postgres/read.py rename to llmstack/common/blocks/data/store/database/database_reader.py index 93e486db4da..4c278f61f1f 100644 --- a/llmstack/common/blocks/data/store/postgres/read.py +++ b/llmstack/common/blocks/data/store/database/database_reader.py @@ -1,50 +1,23 @@ +import collections +import datetime import json -from collections import defaultdict -from datetime import datetime -from uuid import UUID +import uuid +import sqlalchemy from psycopg2.extras import Range from llmstack.common.blocks.base.processor import ProcessorInterface from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument -from llmstack.common.blocks.data.store.postgres import ( - PostgresConfiguration, - PostgresOutput, - get_pg_connection, +from llmstack.common.blocks.data.store.database.utils import ( + DatabaseConfiguration, + DatabaseConfigurationType, + DatabaseOutput, + get_database_connection, ) -class PostgresReaderInput(BaseSchema): - sql: str - - -types_map = { - 20: "integer", - 21: "integer", - 23: "integer", - 700: "float", - 1700: "float", - 701: "float", - 16: "boolean", - 1082: "date", - 1182: "date", - 1114: "datetime", - 1184: "datetime", - 1115: "datetime", - 1185: "datetime", - 1014: "string", - 1015: "string", - 1008: "string", - 1009: "string", - 2951: "string", - 1043: "string", - 1002: "string", - 1003: "string", -} - - -class PostgreSQLJSONEncoder(json.JSONEncoder): +class DatabaseJSONEncoder(json.JSONEncoder): def default(self, o): if isinstance(o, Range): # From: https://github.com/psycopg/psycopg2/pull/779 @@ -64,20 +37,24 @@ def default(self, o): ] return "".join(items) - elif isinstance(o, UUID): + elif isinstance(o, uuid.UUID): return str(o.hex) - elif isinstance(o, datetime): + elif isinstance(o, (datetime.date, datetime.datetime)): return o.isoformat() - return super(PostgreSQLJSONEncoder, self).default(o) + return super().default(o) -class PostgresReader( - ProcessorInterface[PostgresReaderInput, PostgresOutput, PostgresConfiguration], +class DatabaseReaderInput(BaseSchema): + sql: str + + +class DatabaseReader( + ProcessorInterface[DatabaseReaderInput, DatabaseOutput, DatabaseConfiguration], ): def fetch_columns(self, columns): column_names = set() - duplicates_counters = defaultdict(int) + duplicates_counters = collections.defaultdict(int) new_columns = [] for col in columns: @@ -90,35 +67,35 @@ def fetch_columns(self, columns): ) column_names.add(column_name) - new_columns.append( - {"name": column_name, "friendly_name": column_name, "type": col[1]}, - ) + new_columns.append({"name": column_name, "type": col[1]}) return new_columns def process( self, - input: PostgresReaderInput, - configuration: PostgresConfiguration, - ) -> PostgresOutput: - connection = get_pg_connection(configuration.dict()) - cursor = connection.cursor() + input: DatabaseReaderInput, + configuration: DatabaseConfigurationType, + ) -> DatabaseOutput: + connection = get_database_connection(configuration=configuration) try: - cursor.execute(input.sql) + result = connection.execute(sqlalchemy.text(input.sql)) + cursor = result.cursor + if cursor.description is not None: columns = self.fetch_columns( - [(i[0], types_map.get(i[1], None)) for i in cursor.description], + [(i[0], None) for i in cursor.description], ) rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor] data = {"columns": columns, "rows": rows} - json_data = json.dumps(data, cls=PostgreSQLJSONEncoder) + json_data = json.dumps(data, cls=DatabaseJSONEncoder) else: raise Exception("Query completed but it returned no data.") except Exception as e: - connection.cancel() + connection.close() + connection.engine.dispose() raise e - return PostgresOutput( + return DatabaseOutput( documents=[ DataDocument( content=json_data, diff --git a/llmstack/common/blocks/data/store/database/mysql.py b/llmstack/common/blocks/data/store/database/mysql.py new file mode 100644 index 00000000000..83c3bd602b8 --- /dev/null +++ b/llmstack/common/blocks/data/store/database/mysql.py @@ -0,0 +1,88 @@ +from enum import Enum +from typing import List, Optional + +from typing_extensions import Literal + +from llmstack.common.blocks.base.schema import BaseSchema +from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType + +try: + import MySQLdb + + enabled = True +except ImportError: + enabled = False + + +class SSLMode(str, Enum): + disabled = "DISABLED" + preferred = "PREFERRED" + required = "REQUIRED" + verify_ca = "VERIFY_CA" + verify_identity = "VERIFY_IDENTITY" + + +class MySQLConfiguration(BaseSchema): + engine: Literal[DatabaseEngineType.MYSQL] = DatabaseEngineType.MYSQL + user: Optional[str] + password: Optional[str] + host: str = "127.0.0.1" + port: int = 3306 + dbname: str + use_ssl: bool = False + sslmode: SSLMode = "preferred" + ssl_ca: Optional[str] = None + ssl_cert: Optional[str] = None + ssl_key: Optional[str] = None + + class Config: + schema_extra = { + "order": ["host", "port", "user", "password"], + "required": ["dbname"], + "secret": ["password", "ssl_ca", "ssl_cert", "ssl_key"], + "extra_options": ["sslmode", "ssl_ca", "ssl_cert", "ssl_key"], + } + + +class MySQLOutput(BaseSchema): + documents: List[DataDocument] + + +def get_mysql_ssl_config(configuration: dict): + if not configuration.get("use_ssl"): + return {} + + ssl_config = {"sslmode": configuration.get("sslmode", "prefer")} + + if configuration.get("use_ssl"): + config_map = {"ssl_mode": "preferred", "ssl_cacert": "ca", "ssl_cert": "cert", "ssl_key": "key"} + for key, cfg in config_map.items(): + val = configuration.get(key) + if val: + ssl_config[cfg] = val + + return ssl_config + + +def get_mysql_connection(configuration: dict): + params = dict( + host=configuration.get("host"), + user=configuration.get("user"), + passwd=configuration.get("password"), + db=configuration.get("dbname"), + port=configuration.get("port", 3306), + charset=configuration.get("charset", "utf8"), + use_unicode=configuration.get("use_unicode", True), + connect_timeout=configuration.get("connect_timeout", 60), + autocommit=configuration.get("autocommit", True), + ) + + ssl_options = get_mysql_ssl_config() + + if ssl_options: + params["ssl"] = ssl_options + + connection = MySQLdb.connect(**params) + + return connection diff --git a/llmstack/common/blocks/data/store/postgres/__init__.py b/llmstack/common/blocks/data/store/database/postgresql.py similarity index 86% rename from llmstack/common/blocks/data/store/postgres/__init__.py rename to llmstack/common/blocks/data/store/database/postgresql.py index d8a322169d9..30a39670be3 100644 --- a/llmstack/common/blocks/data/store/postgres/__init__.py +++ b/llmstack/common/blocks/data/store/database/postgresql.py @@ -4,9 +4,11 @@ from typing import List, Optional import psycopg2 +from typing_extensions import Literal from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType class SSLMode(str, Enum): @@ -19,6 +21,7 @@ class SSLMode(str, Enum): class PostgresConfiguration(BaseSchema): + engine: Literal[DatabaseEngineType.POSTGRESQL] = DatabaseEngineType.POSTGRESQL user: Optional[str] password: Optional[str] host: str = "127.0.0.1" @@ -53,7 +56,10 @@ def _create_cert_file(configuration, key, ssl_config): ssl_config[key] = cert_file.name -def _get_ssl_config(configuration: dict): +def get_pg_ssl_config(configuration: dict): + if not configuration.get("use_ssl"): + return {} + ssl_config = {"sslmode": configuration.get("sslmode", "prefer")} _create_cert_file(configuration, "sslrootcert", ssl_config) @@ -65,7 +71,7 @@ def _get_ssl_config(configuration: dict): def get_pg_connection(configuration: dict): ssl_config = ( - _get_ssl_config( + get_pg_ssl_config( configuration, ) if configuration.get("use_ssl") diff --git a/llmstack/common/blocks/data/store/sqlite/__init__.py b/llmstack/common/blocks/data/store/database/sqlite.py similarity index 56% rename from llmstack/common/blocks/data/store/sqlite/__init__.py rename to llmstack/common/blocks/data/store/database/sqlite.py index 324fa2ea71e..c4722760b50 100644 --- a/llmstack/common/blocks/data/store/sqlite/__init__.py +++ b/llmstack/common/blocks/data/store/database/sqlite.py @@ -1,10 +1,14 @@ from typing import List +from typing_extensions import Literal + from llmstack.common.blocks.base.schema import BaseSchema from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType class SQLiteConfiguration(BaseSchema): + engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE dbpath: str diff --git a/llmstack/common/blocks/data/store/database/utils.py b/llmstack/common/blocks/data/store/database/utils.py new file mode 100644 index 00000000000..8259598f503 --- /dev/null +++ b/llmstack/common/blocks/data/store/database/utils.py @@ -0,0 +1,95 @@ +from typing import List, TypeVar + +import sqlalchemy + +from llmstack.common.blocks.base.schema import BaseSchema +from llmstack.common.blocks.data import DataDocument +from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType +from llmstack.common.blocks.data.store.database.mysql import ( + MySQLConfiguration, + get_mysql_ssl_config, +) +from llmstack.common.blocks.data.store.database.postgresql import ( + PostgresConfiguration, + get_pg_ssl_config, +) +from llmstack.common.blocks.data.store.database.sqlite import SQLiteConfiguration + +DATABASES = { + DatabaseEngineType.POSTGRESQL: { + "name": "PostgreSQL", + "driver": "postgresql+psycopg2", + }, + DatabaseEngineType.MYSQL: { + "name": "MySQL", + "driver": "mysql+mysqldb", + }, + DatabaseEngineType.SQLITE: { + "name": "SQLite", + "driver": "sqlite+pysqlite", + }, +} + +DatabaseConfiguration = MySQLConfiguration | PostgresConfiguration | SQLiteConfiguration + +DatabaseConfigurationType = TypeVar("DatabaseConfigurationType", bound=DatabaseConfiguration) + + +class DatabaseOutput(BaseSchema): + documents: List[DataDocument] + + +def get_database_configuration_class(engine: DatabaseEngineType) -> DatabaseConfigurationType: + if engine == DatabaseEngineType.POSTGRESQL: + return PostgresConfiguration + elif engine == DatabaseEngineType.MYSQL: + return MySQLConfiguration + elif engine == DatabaseEngineType.SQLITE: + return SQLiteConfiguration + else: + raise ValueError(f"Unsupported database engine: {engine}") + + +def get_ssl_config(configuration: DatabaseConfigurationType) -> dict: + ssl_config = {} + if configuration.engine == DatabaseEngineType.POSTGRESQL: + ssl_config = get_pg_ssl_config(configuration.dict()) + elif configuration.engine == DatabaseEngineType.MYSQL: + ssl_config = get_mysql_ssl_config(configuration.dict()) + return ssl_config + + +def get_database_connection( + configuration: DatabaseConfigurationType, + ssl_config: dict = None, +) -> sqlalchemy.engine.Connection: + if configuration.engine not in DATABASES: + raise ValueError(f"Unsupported database engine: {configuration.engine}") + + if not ssl_config: + ssl_config = get_ssl_config(configuration) + + database_name = configuration.dbpath if configuration.engine == DatabaseEngineType.SQLITE else configuration.dbname + + connect_args: dict = {} + + if ssl_config: + connect_args["ssl"] = ssl_config + + # Create URL + db_url = sqlalchemy.engine.URL.create( + drivername=DATABASES[configuration.engine]["driver"], + username=configuration.user if hasattr(configuration, "user") else None, + password=configuration.password if hasattr(configuration, "password") else None, + host=configuration.host if hasattr(configuration, "host") else None, + port=configuration.port if hasattr(configuration, "port") else None, + database=database_name, + ) + + # Create engine + engine = sqlalchemy.create_engine(db_url, connect_args=connect_args) + + # Connect to the database + connection = engine.connect() + + return connection diff --git a/llmstack/common/blocks/data/store/sqlite/read.py b/llmstack/common/blocks/data/store/sqlite/read.py deleted file mode 100644 index d774a04d0df..00000000000 --- a/llmstack/common/blocks/data/store/sqlite/read.py +++ /dev/null @@ -1,77 +0,0 @@ -import json -import sqlite3 -from collections import defaultdict - -from llmstack.common.blocks.base.processor import ProcessorInterface -from llmstack.common.blocks.base.schema import BaseSchema -from llmstack.common.blocks.data import DataDocument -from llmstack.common.blocks.data.store.sqlite import SQLiteConfiguration, SQLiteOutput - - -class SQLiteReaderInput(BaseSchema): - sql: str - - -class SQLiteReader( - ProcessorInterface[SQLiteReaderInput, SQLiteOutput, SQLiteConfiguration], -): - def fetch_columns(self, columns): - column_names = set() - duplicates_counters = defaultdict(int) - new_columns = [] - - for col in columns: - column_name = col[0] - while column_name in column_names: - duplicates_counters[col[0]] += 1 - column_name = "{}{}".format( - col[0], - duplicates_counters[col[0]], - ) - - column_names.add(column_name) - new_columns.append( - {"name": column_name, "friendly_name": column_name, "type": col[1]}, - ) - - return new_columns - - def process( - self, - input: SQLiteReaderInput, - configuration: SQLiteConfiguration, - ) -> SQLiteOutput: - connection = None - try: - connection = sqlite3.connect(configuration.dbpath) - cursor = connection.cursor() - cursor.execute(input.sql) - - if cursor.description is not None: - columns = self.fetch_columns( - [(i[0], None) for i in cursor.description], - ) - rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor] - - data = {"columns": columns, "rows": rows} - json_data = json.dumps(data) - else: - raise Exception("Query completed but it returned no data.") - except Exception as e: - if connection: - connection.cancel() - raise e - finally: - if connection: - connection.close() - return SQLiteOutput( - documents=[ - DataDocument( - content=json_data, - content_text=json_data, - metadata={ - "mime_type": "application/json", - }, - ), - ], - ) diff --git a/llmstack/common/tests/blocks/data/store/postgres/__init__.py b/llmstack/common/tests/blocks/data/store/database/__init__.py similarity index 100% rename from llmstack/common/tests/blocks/data/store/postgres/__init__.py rename to llmstack/common/tests/blocks/data/store/database/__init__.py diff --git a/llmstack/common/tests/blocks/data/store/sqlite/sample.db b/llmstack/common/tests/blocks/data/store/database/sample.db similarity index 100% rename from llmstack/common/tests/blocks/data/store/sqlite/sample.db rename to llmstack/common/tests/blocks/data/store/database/sample.db diff --git a/llmstack/common/tests/blocks/data/store/database/test_database_reader.py b/llmstack/common/tests/blocks/data/store/database/test_database_reader.py new file mode 100644 index 00000000000..d6d3004a5ec --- /dev/null +++ b/llmstack/common/tests/blocks/data/store/database/test_database_reader.py @@ -0,0 +1,65 @@ +import unittest + +from llmstack.common.blocks.data.store.database.database_reader import ( + DatabaseReader, + DatabaseReaderInput, +) +from llmstack.common.blocks.data.store.database.mysql import MySQLConfiguration +from llmstack.common.blocks.data.store.database.postgresql import PostgresConfiguration +from llmstack.common.blocks.data.store.database.sqlite import SQLiteConfiguration + + +class MySQLReadTest(unittest.TestCase): + def test_read(self): + configuration = MySQLConfiguration( + user="root", + password="", + host="localhost", + port=5432, + dbname="usersdb", + ) + reader_input = DatabaseReaderInput( + sql="SELECT * FROM users", + ) + + response = DatabaseReader().process( + reader_input, + configuration, + ) + + self.assertEqual(len(response.documents), 1) + + +class PostgresReadTest(unittest.TestCase): + def test_read(self): + configuration = PostgresConfiguration( + user="root", + password="", + host="localhost", + port=5432, + dbname="usersdb", + ) + reader_input = DatabaseReaderInput( + sql="SELECT * FROM users", + ) + + response = DatabaseReader().process( + reader_input, + configuration, + ) + + self.assertEqual(len(response.documents), 1) + + +class SqliteReadTest(unittest.TestCase): + def test_read(self): + sample_db = f"{'/'.join((__file__.split('/')[:-1]))}/sample.db" + response = DatabaseReader().process( + DatabaseReaderInput( + sql="SELECT * FROM users", + ), + SQLiteConfiguration( + dbpath=sample_db, + ), + ) + self.assertEqual(len(response.documents), 1) diff --git a/llmstack/common/tests/blocks/data/store/sqlite/__init__.py b/llmstack/common/tests/blocks/data/store/sqlite/__init__.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/llmstack/common/tests/blocks/data/store/sqlite/test_read.py b/llmstack/common/tests/blocks/data/store/sqlite/test_read.py deleted file mode 100644 index f8929bf4ba1..00000000000 --- a/llmstack/common/tests/blocks/data/store/sqlite/test_read.py +++ /dev/null @@ -1,21 +0,0 @@ -import unittest - -from llmstack.common.blocks.data.store.sqlite import SQLiteConfiguration -from llmstack.common.blocks.data.store.sqlite.read import ( - SQLiteReader, - SQLiteReaderInput, -) - - -class SqliteReadTest(unittest.TestCase): - def test_read(self): - sample_db = f"{'/'.join((__file__.split('/')[:-1]))}/sample.db" - response = SQLiteReader().process( - SQLiteReaderInput( - sql="SELECT * FROM users", - ), - SQLiteConfiguration( - dbpath=sample_db, - ), - ) - self.assertEquals(len(response.documents), 1) diff --git a/llmstack/datasources/apis.py b/llmstack/datasources/apis.py index 70b03983978..d01285acf57 100644 --- a/llmstack/datasources/apis.py +++ b/llmstack/datasources/apis.py @@ -33,16 +33,15 @@ class DataSourceTypeViewSet(viewsets.ModelViewSet): - queryset = DataSourceType.objects.all() serializer_class = DataSourceTypeSerializer + def get_queryset(self): + return DataSourceType.objects.all() + def get(self, request): - return DRFResponse( - DataSourceTypeSerializer( - instance=self.queryset, - many=True, - ).data, - ) + queryset = self.get_queryset() + serialzer = self.serializer_class(instance=queryset, many=True) + return DRFResponse(serialzer.data) class DataSourceEntryViewSet(viewsets.ModelViewSet): diff --git a/llmstack/datasources/handlers/databases/postgres.py b/llmstack/datasources/handlers/databases/sql.py similarity index 54% rename from llmstack/datasources/handlers/databases/postgres.py rename to llmstack/datasources/handlers/databases/sql.py index 8e09ef31f2b..812b418f640 100644 --- a/llmstack/datasources/handlers/databases/postgres.py +++ b/llmstack/datasources/handlers/databases/sql.py @@ -1,14 +1,18 @@ import json import logging -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Union from pydantic import Field +from typing_extensions import Literal from llmstack.common.blocks.base.schema import BaseSchema as _Schema -from llmstack.common.blocks.data.store.postgres import PostgresConfiguration -from llmstack.common.blocks.data.store.postgres.read import ( - PostgresReader, - PostgresReaderInput, +from llmstack.common.blocks.data.store.database.database_reader import ( + DatabaseReader, + DatabaseReaderInput, +) +from llmstack.common.blocks.data.store.database.utils import ( + DatabaseEngineType, + get_database_configuration_class, ) from llmstack.common.blocks.data.store.vectorstore import Document from llmstack.common.utils.models import Config @@ -22,83 +26,123 @@ logger = logging.getLogger(__name__) -class PostgresConnection(_Schema): - host: str = Field(description="Host of the Postgres instance") +class PostgreSQLConnection(_Schema): + engine: Literal[DatabaseEngineType.POSTGRESQL] = DatabaseEngineType.POSTGRESQL + host: str = Field(description="Host of the PostgreSQL instance") port: int = Field( - description="Port number to connect to the Postgres instance", + description="Port number to connect to the PostgreSQL instance", ) - database_name: str = Field(description="Postgres database name") - username: str = Field(description="Postgres username") - password: Optional[str] = Field(description="Postgres password") + database_name: str = Field(description="PostgreSQL database name") + username: str = Field(description="PostgreSQL database username") + password: Optional[str] = Field(description="PostgreSQL database password") + + class Config: + title = "PostgreSQL" -class PostgresDatabaseSchema(DataSourceSchema): - connection: Optional[PostgresConnection] = Field( - description="Postgres connection details", +class MySQLConnection(_Schema): + engine: Literal[DatabaseEngineType.MYSQL] = DatabaseEngineType.MYSQL + host: str = Field(description="Host of the MySQL instance") + port: int = Field( + description="Port number to connect to the MySQL instance", ) + database_name: str = Field(description="MySQL database name") + username: str = Field(description="MySQL database username") + password: Optional[str] = Field(description="MySQL database password") + + class Config: + title = "MySQL" + +class SQLiteConnection(_Schema): + engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE + database_path: str = Field(description="MySQL database name") -class PostgresConnectionConfiguration(Config): - config_type = "postgres_connection" + class Config: + title = "SQLite" + + +SQLConnection = Union[PostgreSQLConnection, MySQLConnection, SQLiteConnection] + + +class SQLDatabaseSchema(DataSourceSchema): + connection: Optional[SQLConnection] = Field( + title="Database", + description="Database details", + ) + + +class SQLConnectionConfiguration(Config): + config_type: Optional[str] = "sql_connection" is_encrypted = True - postgres_config: Optional[Dict] + config: Optional[Dict] -class PostgresDataSource(DataSourceProcessor[PostgresDatabaseSchema]): +class SQLDataSource(DataSourceProcessor[SQLDatabaseSchema]): # Initializer for the class. # It requires a datasource object as input, checks if it has a 'data' # configuration, and sets up Weaviate Database Configuration. def __init__(self, datasource: DataSource): self.datasource = datasource + if self.datasource.config and "data" in self.datasource.config: - config_dict = PostgresConnectionConfiguration().from_dict( + config_dict = SQLConnectionConfiguration().from_dict( self.datasource.config, self.datasource.profile.decrypt_value, ) - self._configuration = PostgresDatabaseSchema( - **config_dict["postgres_config"], - ) - self._reader_configuration = PostgresConfiguration( - user=self._configuration.connection.username, - password=self._configuration.connection.password, - host=self._configuration.connection.host, - port=self._configuration.connection.port, - dbname=self._configuration.connection.database_name, - use_ssl=False, + + self._configuration = SQLDatabaseSchema( + **config_dict["config"], ) + + database_configuration_class = get_database_configuration_class(self._configuration.connection.engine) + + if self._configuration.connection.engine == DatabaseEngineType.SQLITE: + self._reader_configuration = database_configuration_class( + engine=self._configuration.connection.engine, + dbpath=self._configuration.connection.database_path, + ) + else: + self._reader_configuration = database_configuration_class( + engine=self._configuration.connection.engine, + user=self._configuration.connection.username, + password=self._configuration.connection.password, + host=self._configuration.connection.host, + port=self._configuration.connection.port, + dbname=self._configuration.connection.database_name, + use_ssl=False, + ) self._source_name = self.datasource.name @staticmethod def name() -> str: - return "Postgres" + return "SQL" @staticmethod def slug() -> str: - return "postgres" + return "sql" @staticmethod def description() -> str: - return "Connect to a Postgres database" + return "Connect to a SQL Database" # This static method takes a dictionary for configuration and a DataSource object as inputs. # Validation of these inputs is performed and a dictionary containing the - # Postgres Connection Configuration is returned. + # Database Connection Configuration is returned. @staticmethod def process_validate_config( config_data: dict, datasource: DataSource, ) -> dict: - return PostgresConnectionConfiguration( - postgres_config=config_data, + return SQLConnectionConfiguration( + config=config_data, ).to_dict( encrypt_fn=datasource.profile.encrypt_value, ) - # This static method returns the provider slug for the datasource - # connector. @staticmethod def provider_slug() -> str: - return "postgres" + return "promptly" def validate_and_process(self, data: dict) -> List[DataSourceEntryItem]: raise NotImplementedError @@ -110,10 +154,10 @@ def add_entry(self, data: dict) -> Optional[DataSourceEntryItem]: raise NotImplementedError def similarity_search(self, query: str, **kwargs) -> List[dict]: - pg_client = PostgresReader() + client = DatabaseReader() result = ( - pg_client.process( - PostgresReaderInput( + client.process( + DatabaseReaderInput( sql=query, ), configuration=self._reader_configuration, @@ -148,10 +192,10 @@ def similarity_search(self, query: str, **kwargs) -> List[dict]: ] def hybrid_search(self, query: str, **kwargs) -> List[dict]: - pg_client = PostgresReader() + client = DatabaseReader() result = ( - pg_client.process( - PostgresReaderInput( + client.process( + DatabaseReaderInput( sql=query, ), configuration=self._reader_configuration, diff --git a/llmstack/datasources/urls.py b/llmstack/datasources/urls.py index fc42285edc3..3c4cabbc4c9 100644 --- a/llmstack/datasources/urls.py +++ b/llmstack/datasources/urls.py @@ -8,10 +8,6 @@ "api/datasource_types", apis.DataSourceTypeViewSet.as_view({"get": "get"}), ), - path( - "api/datasource_types", - apis.DataSourceTypeViewSet.as_view({"get": "get"}), - ), # Data sources path( "api/datasources", diff --git a/poetry.lock b/poetry.lock index 65a99e360bf..5e2d796a6cd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3300,6 +3300,24 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "mysqlclient" +version = "2.2.4" +description = "Python interface to MySQL" +optional = false +python-versions = ">=3.8" +files = [ + {file = "mysqlclient-2.2.4-cp310-cp310-win_amd64.whl", hash = "sha256:ac44777eab0a66c14cb0d38965572f762e193ec2e5c0723bcd11319cc5b693c5"}, + {file = "mysqlclient-2.2.4-cp311-cp311-win_amd64.whl", hash = "sha256:329e4eec086a2336fe3541f1ce095d87a6f169d1cc8ba7b04ac68bcb234c9711"}, + {file = "mysqlclient-2.2.4-cp312-cp312-win_amd64.whl", hash = "sha256:e1ebe3f41d152d7cb7c265349fdb7f1eca86ccb0ca24a90036cde48e00ceb2ab"}, + {file = "mysqlclient-2.2.4-cp38-cp38-win_amd64.whl", hash = "sha256:3c318755e06df599338dad7625f884b8a71fcf322a9939ef78c9b3db93e1de7a"}, + {file = "mysqlclient-2.2.4-cp39-cp39-win_amd64.whl", hash = "sha256:9d4c015480c4a6b2b1602eccd9846103fc70606244788d04aa14b31c4bd1f0e2"}, + {file = "mysqlclient-2.2.4-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:d43987bb9626096a302ca6ddcdd81feaeca65ced1d5fe892a6a66b808326aa54"}, + {file = "mysqlclient-2.2.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:4e80dcad884dd6e14949ac6daf769123223a52a6805345608bf49cdaf7bc8b3a"}, + {file = "mysqlclient-2.2.4-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:9d3310295cb682232cadc28abd172f406c718b9ada41d2371259098ae37779d3"}, + {file = "mysqlclient-2.2.4.tar.gz", hash = "sha256:33bc9fb3464e7d7c10b1eaf7336c5ff8f2a3d3b88bab432116ad2490beb3bf41"}, +] + [[package]] name = "ncclient" version = "0.6.15" @@ -7849,4 +7867,4 @@ networking = ["junos-eznc"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "5ad8a217f0ad162ed2bc39aa51877b77d7ca31bf8bde51059327387dc6572e9c" +content-hash = "d9855460e88ca9ea5f706bd28bc271745eda573480e66f10da686f8ce3f78673" diff --git a/pyproject.toml b/pyproject.toml index 6a80b38ddf3..7182a78728d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,6 +128,7 @@ croniter = "^2.0.1" drf-yaml = "^3.0.1" google-generativeai = "^0.3.1" scrapy-playwright = "^0.0.33" +mysqlclient = "^2.2.4" django-ratelimit = "^4.1.0" [tool.poetry.extras]