Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_read(self):
user="root",
password="",
host="localhost",
port=5432,
port=3306,
dbname="usersdb",
)
reader_input = DatabaseReaderInput(
Expand Down
35 changes: 27 additions & 8 deletions llmstack/datasources/handlers/databases/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic import Field
from typing_extensions import Literal

from llmstack.base.models import Profile
from llmstack.common.blocks.base.schema import BaseSchema as _Schema
from llmstack.common.blocks.data.store.database.database_reader import (
DatabaseReader,
Expand All @@ -16,6 +17,7 @@
)
from llmstack.common.blocks.data.store.vectorstore import Document
from llmstack.common.utils.models import Config
from llmstack.connections.models import ConnectionType
from llmstack.datasources.handlers.datasource_processor import (
DataSourceEntryItem,
DataSourceProcessor,
Expand All @@ -33,8 +35,6 @@ class PostgreSQLConnection(_Schema):
description="Port number to connect to the PostgreSQL instance",
)
database_name: str = Field(description="PostgreSQL database name")
username: str = Field(description="PostgreSQL database username")
password: Optional[str] = Field(description="PostgreSQL database password")

class Config:
title = "PostgreSQL"
Expand All @@ -47,16 +47,14 @@ class MySQLConnection(_Schema):
description="Port number to connect to the MySQL instance",
)
database_name: str = Field(description="MySQL database name")
username: str = Field(description="MySQL database username")
password: Optional[str] = Field(description="MySQL database password")

class Config:
title = "MySQL"


class SQLiteConnection(_Schema):
engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE
database_path: str = Field(description="MySQL database name")
database_path: str = Field(description="SQLite database file path")

class Config:
title = "SQLite"
Expand All @@ -68,7 +66,12 @@ class Config:
class SQLDatabaseSchema(DataSourceSchema):
connection: Optional[SQLConnection] = Field(
title="Database",
description="Database details",
# description="Database details",
)
connection_id: Optional[str] = Field(
widget="connection",
advanced_parameter=False,
description="Use your authenticated connection to the database",
)


Expand All @@ -84,6 +87,8 @@ class SQLDataSource(DataSourceProcessor[SQLDatabaseSchema]):
# configuration, and sets up Weaviate Database Configuration.
def __init__(self, datasource: DataSource):
self.datasource = datasource
self.profile = Profile.objects.get(user=self.datasource.owner)
self._env = self.profile.get_vendor_env()

if self.datasource.config and "data" in self.datasource.config:
config_dict = SQLConnectionConfiguration().from_dict(
Expand All @@ -103,10 +108,24 @@ def __init__(self, datasource: DataSource):
dbpath=self._configuration.connection.database_path,
)
else:
username = password = None

connection = (
self._env["connections"].get(
self._configuration.connection_id,
None,
)
if self._configuration.connection_id
else None
)
if connection and connection["base_connection_type"] == ConnectionType.CREDENTIALS:
username = connection["configuration"]["username"]
password = connection["configuration"]["password"]

self._reader_configuration = database_configuration_class(
engine=self._configuration.connection.engine,
user=self._configuration.connection.username,
password=self._configuration.connection.password,
user=username,
password=password,
host=self._configuration.connection.host,
port=self._configuration.connection.port,
dbname=self._configuration.connection.database_name,
Expand Down