Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"editor.tabSize": 4,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.codeActionsOnSave": {
"source.organizeImports": true
"source.organizeImports": "explicit"
}
},
"isort.args": ["--profile", "black"],
Expand Down
7 changes: 7 additions & 0 deletions llmstack/common/blocks/data/store/database/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import StrEnum


class DatabaseEngineType(StrEnum):
POSTGRESQL = "postgresql"
MYSQL = "mysql"
SQLITE = "sqlite"
Original file line number Diff line number Diff line change
@@ -1,50 +1,23 @@
import collections
import datetime
import json
from collections import defaultdict
from datetime import datetime
from uuid import UUID
import uuid

import sqlalchemy
from psycopg2.extras import Range

from llmstack.common.blocks.base.processor import ProcessorInterface
from llmstack.common.blocks.base.schema import BaseSchema
from llmstack.common.blocks.data import DataDocument
from llmstack.common.blocks.data.store.postgres import (
PostgresConfiguration,
PostgresOutput,
get_pg_connection,
from llmstack.common.blocks.data.store.database.utils import (
DatabaseConfiguration,
DatabaseConfigurationType,
DatabaseOutput,
get_database_connection,
)


class PostgresReaderInput(BaseSchema):
sql: str


types_map = {
20: "integer",
21: "integer",
23: "integer",
700: "float",
1700: "float",
701: "float",
16: "boolean",
1082: "date",
1182: "date",
1114: "datetime",
1184: "datetime",
1115: "datetime",
1185: "datetime",
1014: "string",
1015: "string",
1008: "string",
1009: "string",
2951: "string",
1043: "string",
1002: "string",
1003: "string",
}


class PostgreSQLJSONEncoder(json.JSONEncoder):
class DatabaseJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, Range):
# From: https://github.com/psycopg/psycopg2/pull/779
Expand All @@ -64,20 +37,24 @@ def default(self, o):
]

return "".join(items)
elif isinstance(o, UUID):
elif isinstance(o, uuid.UUID):
return str(o.hex)
elif isinstance(o, datetime):
elif isinstance(o, (datetime.date, datetime.datetime)):
return o.isoformat()

return super(PostgreSQLJSONEncoder, self).default(o)
return super().default(o)


class PostgresReader(
ProcessorInterface[PostgresReaderInput, PostgresOutput, PostgresConfiguration],
class DatabaseReaderInput(BaseSchema):
    """Input schema for the database reader: a single SQL statement."""

    # SQL text that the reader executes against the configured database.
    sql: str


class DatabaseReader(
ProcessorInterface[DatabaseReaderInput, DatabaseOutput, DatabaseConfiguration],
):
def fetch_columns(self, columns):
column_names = set()
duplicates_counters = defaultdict(int)
duplicates_counters = collections.defaultdict(int)
new_columns = []

for col in columns:
Expand All @@ -90,35 +67,35 @@ def fetch_columns(self, columns):
)

column_names.add(column_name)
new_columns.append(
{"name": column_name, "friendly_name": column_name, "type": col[1]},
)
new_columns.append({"name": column_name, "type": col[1]})

return new_columns

def process(
self,
input: PostgresReaderInput,
configuration: PostgresConfiguration,
) -> PostgresOutput:
connection = get_pg_connection(configuration.dict())
cursor = connection.cursor()
input: DatabaseReaderInput,
configuration: DatabaseConfigurationType,
) -> DatabaseOutput:
connection = get_database_connection(configuration=configuration)
try:
cursor.execute(input.sql)
result = connection.execute(sqlalchemy.text(input.sql))
cursor = result.cursor

if cursor.description is not None:
columns = self.fetch_columns(
[(i[0], types_map.get(i[1], None)) for i in cursor.description],
[(i[0], None) for i in cursor.description],
)
rows = [dict(zip((column["name"] for column in columns), row)) for row in cursor]

data = {"columns": columns, "rows": rows}
json_data = json.dumps(data, cls=PostgreSQLJSONEncoder)
json_data = json.dumps(data, cls=DatabaseJSONEncoder)
else:
raise Exception("Query completed but it returned no data.")
except Exception as e:
connection.cancel()
connection.close()
connection.engine.dispose()
raise e
return PostgresOutput(
return DatabaseOutput(
documents=[
DataDocument(
content=json_data,
Expand Down
88 changes: 88 additions & 0 deletions llmstack/common/blocks/data/store/database/mysql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from enum import Enum
from typing import List, Optional

from typing_extensions import Literal

from llmstack.common.blocks.base.schema import BaseSchema
from llmstack.common.blocks.data import DataDocument
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType

try:
import MySQLdb

enabled = True
except ImportError:
enabled = False


class SSLMode(str, Enum):
    """SSL mode choices offered by ``MySQLConfiguration.sslmode``.

    Members are ``str`` subclasses, so each one compares equal to its
    upper-case string value.
    """

    disabled = "DISABLED"
    preferred = "PREFERRED"
    required = "REQUIRED"
    verify_ca = "VERIFY_CA"
    verify_identity = "VERIFY_IDENTITY"


class MySQLConfiguration(BaseSchema):
    """Connection settings for a MySQL database.

    The ``engine`` field is pinned to ``DatabaseEngineType.MYSQL`` so the
    schema can be told apart from the other engine configurations.
    """

    engine: Literal[DatabaseEngineType.MYSQL] = DatabaseEngineType.MYSQL
    user: Optional[str]
    password: Optional[str]
    host: str = "127.0.0.1"
    port: int = 3306
    dbname: str
    use_ssl: bool = False
    # Default must be an actual SSLMode member: the previous literal
    # "preferred" is not one of the enum's (upper-case) values, so a
    # default instance could not be re-validated from its own dict().
    sslmode: SSLMode = SSLMode.preferred
    ssl_ca: Optional[str] = None
    ssl_cert: Optional[str] = None
    ssl_key: Optional[str] = None

    class Config:
        # Extra keys merged into the generated schema.
        # NOTE(review): how "order" / "secret" / "extra_options" are
        # interpreted is decided by the schema consumer, not visible here.
        schema_extra = {
            "order": ["host", "port", "user", "password"],
            "required": ["dbname"],
            "secret": ["password", "ssl_ca", "ssl_cert", "ssl_key"],
            "extra_options": ["sslmode", "ssl_ca", "ssl_cert", "ssl_key"],
        }


class MySQLOutput(BaseSchema):
    """Output schema carrying a list of ``DataDocument`` objects."""

    # Documents produced for the caller.
    documents: List[DataDocument]


def get_mysql_ssl_config(configuration: dict):
    """Build the SSL options dict for a MySQL connection.

    Args:
        configuration: Plain dict of connection settings (for example the
            ``dict()`` of a ``MySQLConfiguration``).

    Returns:
        ``{}`` when ``use_ssl`` is falsy. Otherwise a dict that always holds
        ``"sslmode"`` and, for every certificate setting that is present and
        non-empty, the matching ``"ca"`` / ``"cert"`` / ``"key"`` entry.
    """
    if not configuration.get("use_ssl"):
        return {}

    ssl_config = {"sslmode": configuration.get("sslmode", "prefer")}

    # Configuration key -> key in the returned SSL options.
    # "ssl_ca" is the field actually declared on MySQLConfiguration; the
    # original map only looked for "ssl_cacert", so the CA certificate was
    # silently dropped. "ssl_cacert" is still honoured for older callers.
    # NOTE(review): the "ssl_mode" -> "preferred" entry is kept unchanged for
    # backward compatibility; it looks unintended — confirm and remove.
    config_map = {
        "ssl_mode": "preferred",
        "ssl_cacert": "ca",
        "ssl_ca": "ca",
        "ssl_cert": "cert",
        "ssl_key": "key",
    }
    for key, cfg in config_map.items():
        val = configuration.get(key)
        if val:
            ssl_config[cfg] = val

    return ssl_config


def get_mysql_connection(configuration: dict):
    """Open a MySQLdb connection from a plain configuration dict.

    Args:
        configuration: Connection settings. ``host``, ``user``, ``password``
            and ``dbname`` are read as-is; ``port``, ``charset``,
            ``use_unicode``, ``connect_timeout`` and ``autocommit`` fall back
            to the defaults shown below when absent.

    Returns:
        The connection object returned by ``MySQLdb.connect``.
    """
    params = dict(
        host=configuration.get("host"),
        user=configuration.get("user"),
        passwd=configuration.get("password"),
        db=configuration.get("dbname"),
        port=configuration.get("port", 3306),
        charset=configuration.get("charset", "utf8"),
        use_unicode=configuration.get("use_unicode", True),
        connect_timeout=configuration.get("connect_timeout", 60),
        autocommit=configuration.get("autocommit", True),
    )

    # The helper requires the configuration; calling it with no argument
    # (as before) raised TypeError on every invocation.
    ssl_options = get_mysql_ssl_config(configuration)

    if ssl_options:
        params["ssl"] = ssl_options

    connection = MySQLdb.connect(**params)

    return connection
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
from typing import List, Optional

import psycopg2
from typing_extensions import Literal

from llmstack.common.blocks.base.schema import BaseSchema
from llmstack.common.blocks.data import DataDocument
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType


class SSLMode(str, Enum):
Expand All @@ -19,6 +21,7 @@ class SSLMode(str, Enum):


class PostgresConfiguration(BaseSchema):
engine: Literal[DatabaseEngineType.POSTGRESQL] = DatabaseEngineType.POSTGRESQL
user: Optional[str]
password: Optional[str]
host: str = "127.0.0.1"
Expand Down Expand Up @@ -53,7 +56,10 @@ def _create_cert_file(configuration, key, ssl_config):
ssl_config[key] = cert_file.name


def _get_ssl_config(configuration: dict):
def get_pg_ssl_config(configuration: dict):
if not configuration.get("use_ssl"):
return {}

ssl_config = {"sslmode": configuration.get("sslmode", "prefer")}

_create_cert_file(configuration, "sslrootcert", ssl_config)
Expand All @@ -65,7 +71,7 @@ def _get_ssl_config(configuration: dict):

def get_pg_connection(configuration: dict):
ssl_config = (
_get_ssl_config(
get_pg_ssl_config(
configuration,
)
if configuration.get("use_ssl")
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
from typing import List

from typing_extensions import Literal

from llmstack.common.blocks.base.schema import BaseSchema
from llmstack.common.blocks.data import DataDocument
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType


class SQLiteConfiguration(BaseSchema):
    """Connection settings for a SQLite database."""

    # Pinned discriminator identifying this schema as the SQLite one.
    engine: Literal[DatabaseEngineType.SQLITE] = DatabaseEngineType.SQLITE
    # Path of the SQLite database.
    dbpath: str


Expand Down
95 changes: 95 additions & 0 deletions llmstack/common/blocks/data/store/database/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from typing import List, TypeVar

import sqlalchemy

from llmstack.common.blocks.base.schema import BaseSchema
from llmstack.common.blocks.data import DataDocument
from llmstack.common.blocks.data.store.database.constants import DatabaseEngineType
from llmstack.common.blocks.data.store.database.mysql import (
MySQLConfiguration,
get_mysql_ssl_config,
)
from llmstack.common.blocks.data.store.database.postgresql import (
PostgresConfiguration,
get_pg_ssl_config,
)
from llmstack.common.blocks.data.store.database.sqlite import SQLiteConfiguration

# Per-engine metadata: a human-readable name and the SQLAlchemy driver
# string that is passed as `drivername` when the connection URL is built.
DATABASES = {
    DatabaseEngineType.POSTGRESQL: {
        "name": "PostgreSQL",
        "driver": "postgresql+psycopg2",
    },
    DatabaseEngineType.MYSQL: {
        "name": "MySQL",
        "driver": "mysql+mysqldb",
    },
    DatabaseEngineType.SQLITE: {
        "name": "SQLite",
        "driver": "sqlite+pysqlite",
    },
}

# Union of every engine-specific configuration schema.
DatabaseConfiguration = MySQLConfiguration | PostgresConfiguration | SQLiteConfiguration

# Type variable bounded by the configuration union, used in annotations below.
DatabaseConfigurationType = TypeVar("DatabaseConfigurationType", bound=DatabaseConfiguration)


class DatabaseOutput(BaseSchema):
    """Output schema carrying a list of ``DataDocument`` objects."""

    # Documents produced for the caller.
    documents: List[DataDocument]


def get_database_configuration_class(engine: DatabaseEngineType) -> type[DatabaseConfiguration]:
    """Return the configuration schema class for a database engine.

    The return annotation is the *class* (``type[...]``), matching what the
    body returns; it previously named the instance-level type variable.

    Args:
        engine: Engine identifier to look up.

    Returns:
        ``PostgresConfiguration``, ``MySQLConfiguration`` or
        ``SQLiteConfiguration`` (the class itself, not an instance).

    Raises:
        ValueError: If ``engine`` is not one of the supported engines.
    """
    if engine == DatabaseEngineType.POSTGRESQL:
        return PostgresConfiguration
    elif engine == DatabaseEngineType.MYSQL:
        return MySQLConfiguration
    elif engine == DatabaseEngineType.SQLITE:
        return SQLiteConfiguration
    else:
        raise ValueError(f"Unsupported database engine: {engine}")


def get_ssl_config(configuration: DatabaseConfigurationType) -> dict:
    """Return the engine-specific SSL options for a configuration.

    PostgreSQL and MySQL configurations are converted with ``.dict()`` and
    handed to their respective helper; any other engine yields ``{}``.
    """
    if configuration.engine == DatabaseEngineType.POSTGRESQL:
        return get_pg_ssl_config(configuration.dict())

    if configuration.engine == DatabaseEngineType.MYSQL:
        return get_mysql_ssl_config(configuration.dict())

    # Engines without an SSL helper get no SSL options.
    return {}


def get_database_connection(
    configuration: DatabaseConfigurationType,
    ssl_config: dict = None,
) -> sqlalchemy.engine.Connection:
    """Create a SQLAlchemy engine for ``configuration`` and connect to it.

    Args:
        configuration: Engine-specific configuration object. Its ``engine``
            attribute selects the driver from ``DATABASES``.
        ssl_config: Optional SSL options. When empty or ``None`` they are
            derived from ``configuration`` via ``get_ssl_config``.

    Returns:
        An open ``sqlalchemy.engine.Connection``. The engine that owns it is
        reachable through ``connection.engine``.

    Raises:
        ValueError: If ``configuration.engine`` is not listed in ``DATABASES``.
    """
    if configuration.engine not in DATABASES:
        raise ValueError(f"Unsupported database engine: {configuration.engine}")

    if not ssl_config:
        ssl_config = get_ssl_config(configuration)

    # SQLite is addressed by file path; the other engines by database name.
    database_name = configuration.dbpath if configuration.engine == DatabaseEngineType.SQLITE else configuration.dbname

    connect_args: dict = {}

    if ssl_config:
        if configuration.engine == DatabaseEngineType.POSTGRESQL:
            # psycopg2 expects the libpq options (sslmode, sslrootcert, ...)
            # as top-level keyword arguments; nesting them under "ssl" is
            # rejected as an unknown connection option.
            connect_args.update(ssl_config)
        else:
            connect_args["ssl"] = ssl_config

    # Create URL; attributes a configuration does not define are left unset.
    db_url = sqlalchemy.engine.URL.create(
        drivername=DATABASES[configuration.engine]["driver"],
        username=getattr(configuration, "user", None),
        password=getattr(configuration, "password", None),
        host=getattr(configuration, "host", None),
        port=getattr(configuration, "port", None),
        database=database_name,
    )

    # Create engine
    engine = sqlalchemy.create_engine(db_url, connect_args=connect_args)

    # Connect to the database; release the engine's pool if that fails, since
    # the caller never receives a handle through which to dispose of it.
    try:
        connection = engine.connect()
    except Exception:
        engine.dispose()
        raise

    return connection
Loading