diff --git a/llmstack/common/tests/blocks/data/store/database/test_database_reader.py b/llmstack/common/tests/blocks/data/store/database/test_database_reader.py index d6d3004a5ec..15edd9dac2c 100644 --- a/llmstack/common/tests/blocks/data/store/database/test_database_reader.py +++ b/llmstack/common/tests/blocks/data/store/database/test_database_reader.py @@ -15,7 +15,7 @@ def test_read(self): user="root", password="", host="localhost", - port=5432, + port=3306, dbname="usersdb", ) reader_input = DatabaseReaderInput( diff --git a/llmstack/datasources/handlers/databases/sql.py b/llmstack/datasources/handlers/databases/sql.py index 812b418f640..44d93b32d7b 100644 --- a/llmstack/datasources/handlers/databases/sql.py +++ b/llmstack/datasources/handlers/databases/sql.py @@ -5,6 +5,7 @@ from pydantic import Field from typing_extensions import Literal +from llmstack.base.models import Profile from llmstack.common.blocks.base.schema import BaseSchema as _Schema from llmstack.common.blocks.data.store.database.database_reader import ( DatabaseReader, @@ -16,6 +17,7 @@ ) from llmstack.common.blocks.data.store.vectorstore import Document from llmstack.common.utils.models import Config +from llmstack.connections.models import ConnectionType from llmstack.datasources.handlers.datasource_processor import ( DataSourceEntryItem, DataSourceProcessor, @@ -33,8 +35,6 @@ class PostgreSQLConnection(_Schema): description="Port number to connect to the PostgreSQL instance", ) database_name: str = Field(description="PostgreSQL database name") - username: str = Field(description="PostgreSQL database username") - password: Optional[str] = Field(description="PostgreSQL database password") class Config: title = "PostgreSQL" @@ -47,8 +47,6 @@ class MySQLConnection(_Schema): description="Port number to connect to the MySQL instance", ) database_name: str = Field(description="MySQL database name") - username: str = Field(description="MySQL database username") - password: Optional[str] = Field(description="MySQL database password") class Config: title = "MySQL" @@ -56,7 +54,7 @@ class Config: class SQLiteConnection(_Schema): engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE - database_path: str = Field(description="MySQL database name") + database_path: str = Field(description="SQLite database file path") class Config: title = "SQLite" @@ -68,7 +66,12 @@ class Config: class SQLDatabaseSchema(DataSourceSchema): connection: Optional[SQLConnection] = Field( title="Database", - description="Database details", + # description="Database details", + ) + connection_id: Optional[str] = Field( + widget="connection", + advanced_parameter=False, + description="Use your authenticated connection to the database", ) @@ -84,6 +87,8 @@ class SQLDataSource(DataSourceProcessor[SQLDatabaseSchema]): # configuration, and sets up Weaviate Database Configuration. def __init__(self, datasource: DataSource): self.datasource = datasource + self.profile = Profile.objects.get(user=self.datasource.owner) + self._env = self.profile.get_vendor_env() if self.datasource.config and "data" in self.datasource.config: config_dict = SQLConnectionConfiguration().from_dict( @@ -103,10 +108,24 @@ def __init__(self, datasource: DataSource): dbpath=self._configuration.connection.database_path, ) else: + username = password = None + + connection = ( + self._env["connections"].get( + self._configuration.connection_id, + None, + ) + if self._configuration.connection_id + else None + ) + if connection and connection["base_connection_type"] == ConnectionType.CREDENTIALS: + username = connection["configuration"]["username"] + password = connection["configuration"]["password"] + self._reader_configuration = database_configuration_class( engine=self._configuration.connection.engine, - user=self._configuration.connection.username, - password=self._configuration.connection.password, + user=username, + password=password, host=self._configuration.connection.host, port=self._configuration.connection.port, dbname=self._configuration.connection.database_name,