diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..4cafb10
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,9 @@
+*.pyc
+.venv*
+.vscode
+.mypy_cache
+.coverage
+htmlcov
+
+dist
+test.py
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000..9d5b626
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,8 @@
+[flake8]
+max-line-length = 88
+ignore = E203, E241, E501, W503, F811
+exclude =
+ .git,
+ __pycache__
+ .history
+ tests/demo_project
diff --git a/.github/dependabot.yml b/.github/dependabot.yml
new file mode 100644
index 0000000..b9038ca
--- /dev/null
+++ b/.github/dependabot.yml
@@ -0,0 +1,10 @@
+version: 2
+updates:
+ - package-ecosystem: "pip"
+ directory: "/"
+ schedule:
+ interval: "monthly"
+ - package-ecosystem: "github-actions"
+ directory: "/"
+ schedule:
+ interval: monthly
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
new file mode 100644
index 0000000..4d56161
--- /dev/null
+++ b/.github/workflows/publish.yml
@@ -0,0 +1,25 @@
+name: Publish
+
+on:
+ release:
+ types: [published]
+ workflow_dispatch:
+
+jobs:
+ publish:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+ - name: Install Flit
+ run: pip install flit
+ - name: Install Dependencies
+ run: flit install --symlink
+ - name: Publish
+ env:
+ FLIT_USERNAME: ${{ secrets.FLIT_USERNAME }}
+ FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }}
+ run: flit publish
diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
new file mode 100644
index 0000000..860f59f
--- /dev/null
+++ b/.github/workflows/test.yml
@@ -0,0 +1,25 @@
+name: Test
+
+on:
+ push:
+ pull_request:
+ types: [assigned, opened, synchronize, reopened]
+
+jobs:
+ test_coverage:
+ runs-on: ubuntu-latest
+
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+ - name: Install Flit
+ run: pip install flit
+ - name: Install Dependencies
+ run: flit install --symlink
+ - name: Test
+ run: make test-cov
+ - name: Coverage
+ uses: codecov/codecov-action@v3.1.4
diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml
new file mode 100644
index 0000000..f459fdd
--- /dev/null
+++ b/.github/workflows/test_full.yml
@@ -0,0 +1,43 @@
+name: Full Test
+
+on:
+ push:
+ pull_request:
+ types: [assigned, opened, synchronize, reopened]
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: ${{ matrix.python-version }}
+ - name: Install Flit
+ run: pip install flit
+ - name: Install Dependencies
+ run: flit install --symlink
+ - name: Test
+ run: pytest tests
+
+ codestyle:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+ - name: Set up Python
+ uses: actions/setup-python@v4
+ with:
+ python-version: 3.8
+ - name: Install Flit
+ run: pip install flit
+ - name: Install Dependencies
+ run: flit install --symlink
+ - name: Linting check
+ run: ruff check ellar_sqlalchemy tests
+ - name: mypy
+ run: mypy ellar_sqlalchemy
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..a47d824
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,129 @@
+*.pyc
+
+# Byte-compiled / optimized / DLL files
+__pycache__
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+# *.mo Needs to come with the package
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+.vscode
+.mypy_cache
+.coverage
+htmlcov
+
+dist
+test.py
+
+docs/site
+
+.DS_Store
+.idea
+local_install.sh
+dist
+test.py
+
+docs/site
+site/
diff --git a/.isort.cfg b/.isort.cfg
new file mode 100644
index 0000000..1815422
--- /dev/null
+++ b/.isort.cfg
@@ -0,0 +1,3 @@
+[settings]
+profile = black
+combine_as_imports = true
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000..0fbadcd
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,46 @@
+repos:
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v2.3.0
+ hooks:
+ - id: check-merge-conflict
+- repo: https://github.com/asottile/yesqa
+ rev: v1.3.0
+ hooks:
+ - id: yesqa
+- repo: local
+ hooks:
+ - id: code_formatting
+ args: []
+ name: Code Formatting
+ entry: "make fmt"
+ types: [python]
+ language_version: python3.8
+ language: python
+ - id: code_linting
+ args: [ ]
+ name: Code Linting
+ entry: "make lint"
+ types: [ python ]
+ language_version: python3.8
+ language: python
+- repo: https://github.com/pre-commit/pre-commit-hooks
+ rev: v2.3.0
+ hooks:
+ - id: end-of-file-fixer
+ exclude: >-
+ ^examples/[^/]*\.svg$
+ - id: requirements-txt-fixer
+ - id: trailing-whitespace
+ types: [python]
+ - id: check-case-conflict
+ - id: check-json
+ - id: check-xml
+ - id: check-executables-have-shebangs
+ - id: check-toml
+ - id: check-xml
+ - id: check-yaml
+ - id: debug-statements
+ - id: check-added-large-files
+ - id: check-symlinks
+ - id: debug-statements
+ exclude: ^tests/
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..0714b54
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,38 @@
+.PHONY: help docs
+.DEFAULT_GOAL := help
+
+help:
+ @fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
+
+clean: ## Removing cached python compiled files
+ find . -name \*pyc | xargs rm -fv
+ find . -name \*pyo | xargs rm -fv
+ find . -name \*~ | xargs rm -fv
+ find . -name __pycache__ | xargs rm -rfv
+ find . -name .ruff_cache | xargs rm -rfv
+
+install: ## Install dependencies
+ flit install --deps develop --symlink
+
+install-full: ## Install dependencies
+ make install
+ pre-commit install -f
+
+lint:fmt ## Run code linters
+ ruff check ellar_sqlalchemy tests
+ mypy ellar_sqlalchemy
+
+fmt format:clean ## Run code formatters
+ ruff format ellar_sqlalchemy tests
+ ruff check --fix ellar_sqlalchemy tests
+
+test: ## Run tests
+ pytest tests
+
+test-cov: ## Run tests with coverage
+ pytest --cov=ellar_sqlalchemy --cov-report term-missing tests
+
+pre-commit-lint: ## Runs Requires commands during pre-commit
+ make clean
+ make fmt
+ make lint
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..a14fc58
--- /dev/null
+++ b/README.md
@@ -0,0 +1,157 @@
+
+
+
+
+
+
+[](https://badge.fury.io/py/ellar-sqlachemy)
+[](https://pypi.python.org/pypi/ellar-sqlachemy)
+[](https://pypi.python.org/pypi/ellar-sqlachemy)
+
+## Project Status
+- 70% done
+- SQLAlchemy Table support with ModelSession
+- Migration custom revision directives
+- Documentation
+- File Field
+- Image Field
+
+## Introduction
+Ellar SQLAlchemy Module simplifies the integration of SQLAlchemy and Alembic migration tooling into your ellar application.
+
+## Installation
+```shell
+$(venv) pip install ellar-sqlalchemy
+```
+
+## Features
+- Automatic table name
+- Session management during request and after request
+- Support both async/sync SQLAlchemy operations in Session, Engine, and Connection.
+- Multiple Database Support
+- Database migrations for both single and multiple databases either async/sync database engine
+
+## **Usage**
+In your ellar application, create a module called `db` or any name of your choice,
+```shell
+ellar create-module db
+```
+Then, in `models/base.py` define your model base as shown below:
+
+```python
+# db/models/base.py
+from datetime import datetime
+from sqlalchemy import DateTime, func
+from sqlalchemy.orm import Mapped, mapped_column
+from ellar_sqlalchemy.model import Model
+
+
+class Base(Model, as_base=True):
+ __database__ = 'default'
+
+ created_date: Mapped[datetime] = mapped_column(
+ "created_date", DateTime, default=datetime.utcnow, nullable=False
+ )
+
+ time_updated: Mapped[datetime] = mapped_column(
+ "time_updated", DateTime, nullable=False, default=datetime.utcnow, onupdate=func.now()
+ )
+```
+
+Use `Base` to create other models, like users in `User` in
+```python
+# db/models/users.py
+from sqlalchemy import Integer, String
+from sqlalchemy.orm import Mapped, mapped_column
+from .base import Base
+
+
+class User(Base):
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ username: Mapped[str] = mapped_column(String, unique=True, nullable=False)
+ email: Mapped[str] = mapped_column(String)
+```
+
+### Configure Module
+```python
+# db/module.py
+from ellar.app import App
+from ellar.common import Module, IApplicationStartup
+from ellar.core import ModuleBase
+from ellar.di import Container
+from ellar_sqlalchemy import EllarSQLAlchemyModule, EllarSQLAlchemyService
+
+from .controllers import DbController
+
+@Module(
+ controllers=[DbController],
+ providers=[],
+ routers=[],
+ modules=[
+ EllarSQLAlchemyModule.setup(
+ databases={
+ 'default': 'sqlite:///project.db',
+ },
+ echo=True,
+ migration_options={
+ 'directory': '__main__/migrations'
+ },
+ models=['db.models.users']
+ )
+ ]
+)
+class DbModule(ModuleBase, IApplicationStartup):
+ """
+ Db Module
+ """
+
+ async def on_startup(self, app: App) -> None:
+ db_service = app.injector.get(EllarSQLAlchemyService)
+ db_service.create_all()
+
+ def register_providers(self, container: Container) -> None:
+ """for more complicated provider registrations, use container.register_instance(...) """
+```
+
+### Model Usage
+Database session exist at model level and can be accessed through `model.get_db_session()` eg, `User.get_db_session()`
+```python
+# db/models/controllers.py
+from ellar.common import Controller, ControllerBase, get, post, Body
+from pydantic import EmailStr
+from sqlalchemy import select
+
+from .models.users import User
+
+
+@Controller
+class DbController(ControllerBase):
+ @post("/users")
+ async def create_user(self, username: Body[str], email: Body[EmailStr]):
+ session = User.get_db_session()
+ user = User(username=username, email=email)
+
+ session.add(user)
+ session.commit()
+
+ return user.dict()
+
+
+ @get("/users/{user_id:int}")
+ def get_user_by_id(self, user_id: int):
+ session = User.get_db_session()
+ stmt = select(User).filter(User.id==user_id)
+ user = session.execute(stmt).scalar()
+ return user.dict()
+
+ @get("/users")
+ async def get_all_users(self):
+ session = User.get_db_session()
+ stmt = select(User)
+ rows = session.execute(stmt.offset(0).limit(100)).scalars()
+ return [row.dict() for row in rows]
+```
+
+## License
+
+Ellar is [MIT licensed](LICENSE).
diff --git a/ellar_sqlalchemy/__init__.py b/ellar_sqlalchemy/__init__.py
new file mode 100644
index 0000000..e50183a
--- /dev/null
+++ b/ellar_sqlalchemy/__init__.py
@@ -0,0 +1,8 @@
+"""Ellar SQLAlchemy Module - Adds support for SQLAlchemy and Alembic package to your Ellar web Framework"""
+
+__version__ = "0.0.1"
+
+from .module import EllarSQLAlchemyModule
+from .services import EllarSQLAlchemyService
+
+__all__ = ["EllarSQLAlchemyModule", "EllarSQLAlchemyService"]
diff --git a/ellar_sqlalchemy/cli/__init__.py b/ellar_sqlalchemy/cli/__init__.py
new file mode 100644
index 0000000..6951cd7
--- /dev/null
+++ b/ellar_sqlalchemy/cli/__init__.py
@@ -0,0 +1,3 @@
+from .commands import db as DBCommands
+
+__all__ = ["DBCommands"]
diff --git a/ellar_sqlalchemy/cli/commands.py b/ellar_sqlalchemy/cli/commands.py
new file mode 100644
index 0000000..ba624fe
--- /dev/null
+++ b/ellar_sqlalchemy/cli/commands.py
@@ -0,0 +1,404 @@
+import click
+from ellar.app import current_injector
+
+from ellar_sqlalchemy.services import EllarSQLAlchemyService
+
+from .handlers import CLICommandHandlers
+
+
+@click.group()
+def db():
+ """- Perform Alembic Database Commands -"""
+ pass
+
+
+def _get_handler_context(ctx: click.Context) -> CLICommandHandlers:
+ db_service = current_injector.get(EllarSQLAlchemyService)
+ return CLICommandHandlers(db_service)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option("-m", "--message", default=None, help="Revision message")
+@click.option(
+ "--autogenerate",
+ is_flag=True,
+ help=(
+ "Populate revision script with candidate migration "
+ "operations, based on comparison of database to model"
+ ),
+)
+@click.option(
+ "--sql",
+ is_flag=True,
+ help="Don't emit SQL to database - dump to standard output " "instead",
+)
+@click.option(
+ "--head",
+ default="head",
+ help="Specify head revision or @head to base new " "revision on",
+)
+@click.option(
+ "--splice",
+ is_flag=True,
+ help='Allow a non-head revision as the "head" to splice onto',
+)
+@click.option(
+ "--branch-label",
+ default=None,
+ help="Specify a branch label to apply to the new revision",
+)
+@click.option(
+ "--version-path",
+ default=None,
+ help="Specify specific path from config for version file",
+)
+@click.option(
+ "--rev-id",
+ default=None,
+ help="Specify a hardcoded revision id instead of generating " "one",
+)
+@click.pass_context
+def revision(
+ ctx,
+ directory,
+ message,
+ autogenerate,
+ sql,
+ head,
+ splice,
+ branch_label,
+ version_path,
+ rev_id,
+):
+ """- Create a new revision file."""
+ handler = _get_handler_context(ctx)
+ handler.revision(
+ directory,
+ message,
+ autogenerate,
+ sql,
+ head,
+ splice,
+ branch_label,
+ version_path,
+ rev_id,
+ )
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option("-m", "--message", default=None, help="Revision message")
+@click.option(
+ "--sql",
+ is_flag=True,
+ help="Don't emit SQL to database - dump to standard output " "instead",
+)
+@click.option(
+ "--head",
+ default="head",
+ help="Specify head revision or @head to base new " "revision on",
+)
+@click.option(
+ "--splice",
+ is_flag=True,
+ help='Allow a non-head revision as the "head" to splice onto',
+)
+@click.option(
+ "--branch-label",
+ default=None,
+ help="Specify a branch label to apply to the new revision",
+)
+@click.option(
+ "--version-path",
+ default=None,
+ help="Specify specific path from config for version file",
+)
+@click.option(
+ "--rev-id",
+ default=None,
+ help="Specify a hardcoded revision id instead of generating " "one",
+)
+@click.option(
+ "-x",
+ "--x-arg",
+ multiple=True,
+ help="Additional arguments consumed by custom env.py scripts",
+)
+@click.pass_context
+def migrate(
+ ctx,
+ directory,
+ message,
+ sql,
+ head,
+ splice,
+ branch_label,
+ version_path,
+ rev_id,
+ x_arg,
+):
+ """- Autogenerate a new revision file (Alias for
+ 'revision --autogenerate')"""
+ handler = _get_handler_context(ctx)
+ handler.migrate(
+ directory,
+ message,
+ sql,
+ head,
+ splice,
+ branch_label,
+ version_path,
+ rev_id,
+ x_arg,
+ )
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.argument("revision", default="head")
+@click.pass_context
+def edit(ctx, directory, revision):
+ """- Edit a revision file"""
+ handler = _get_handler_context(ctx)
+ handler.edit(directory, revision)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option("-m", "--message", default=None, help="Merge revision message")
+@click.option(
+ "--branch-label",
+ default=None,
+ help="Specify a branch label to apply to the new revision",
+)
+@click.option(
+ "--rev-id",
+ default=None,
+ help="Specify a hardcoded revision id instead of generating " "one",
+)
+@click.argument("revisions", nargs=-1)
+@click.pass_context
+def merge(ctx, directory, message, branch_label, rev_id, revisions):
+ """- Merge two revisions together, creating a new revision file"""
+ handler = _get_handler_context(ctx)
+ handler.merge(directory, revisions, message, branch_label, rev_id)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option(
+ "--sql",
+ is_flag=True,
+ help="Don't emit SQL to database - dump to standard output " "instead",
+)
+@click.option(
+ "--tag",
+ default=None,
+ help='Arbitrary "tag" name - can be used by custom env.py ' "scripts",
+)
+@click.option(
+ "-x",
+ "--x-arg",
+ multiple=True,
+ help="Additional arguments consumed by custom env.py scripts",
+)
+@click.argument("revision", default="head")
+@click.pass_context
+def upgrade(ctx, directory, sql, tag, x_arg, revision):
+ """- Upgrade to a later version"""
+ handler = _get_handler_context(ctx)
+ handler.upgrade(directory, revision, sql, tag, x_arg)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option(
+ "--sql",
+ is_flag=True,
+ help="Don't emit SQL to database - dump to standard output " "instead",
+)
+@click.option(
+ "--tag",
+ default=None,
+ help='Arbitrary "tag" name - can be used by custom env.py ' "scripts",
+)
+@click.option(
+ "-x",
+ "--x-arg",
+ multiple=True,
+ help="Additional arguments consumed by custom env.py scripts",
+)
+@click.argument("revision", default="-1")
+@click.pass_context
+def downgrade(ctx: click.Context, directory, sql, tag, x_arg, revision):
+ """- Revert to a previous version"""
+ handler = _get_handler_context(ctx)
+ handler.downgrade(directory, revision, sql, tag, x_arg)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.argument("revision", default="head")
+@click.pass_context
+def show(ctx: click.Context, directory, revision):
+ """- Show the revision denoted by the given symbol."""
+ handler = _get_handler_context(ctx)
+ handler.show(directory, revision)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option(
+ "-r",
+ "--rev-range",
+ default=None,
+ help="Specify a revision range; format is [start]:[end]",
+)
+@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output")
+@click.option(
+ "-i",
+ "--indicate-current",
+ is_flag=True,
+ help="Indicate current version (Alembic 0.9.9 or greater is " "required)",
+)
+@click.pass_context
+def history(ctx: click.Context, directory, rev_range, verbose, indicate_current):
+ """- List changeset scripts in chronological order."""
+ handler = _get_handler_context(ctx)
+ handler.history(directory, rev_range, verbose, indicate_current)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output")
+@click.option(
+ "--resolve-dependencies",
+ is_flag=True,
+ help="Treat dependency versions as down revisions",
+)
+@click.pass_context
+def heads(ctx: click.Context, directory, verbose, resolve_dependencies):
+ """- Show current available heads in the script directory"""
+ handler = _get_handler_context(ctx)
+ handler.heads(directory, verbose, resolve_dependencies)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output")
+@click.pass_context
+def branches(ctx, directory, verbose):
+ """- Show current branch points"""
+ handler = _get_handler_context(ctx)
+ handler.branches(directory, verbose)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output")
+@click.pass_context
+def current(ctx: click.Context, directory, verbose):
+ """- Display the current revision for each database."""
+ handler = _get_handler_context(ctx)
+ handler.current(directory, verbose)
+
+
+@db.command()
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option(
+ "--sql",
+ is_flag=True,
+ help="Don't emit SQL to database - dump to standard output " "instead",
+)
+@click.option(
+ "--tag",
+ default=None,
+ help='Arbitrary "tag" name - can be used by custom env.py ' "scripts",
+)
+@click.argument("revision", default="head")
+@click.pass_context
+def stamp(ctx: click.Context, directory, sql, tag, revision):
+ """- 'stamp' the revision table with the given revision; don't run any
+ migrations"""
+ handler = _get_handler_context(ctx)
+ handler.stamp(directory, revision, sql, tag)
+
+
+@db.command("init-migration")
+@click.option(
+ "-d",
+ "--directory",
+ default=None,
+ help='Migration script directory (default is "migrations")',
+)
+@click.option(
+ "--package",
+ is_flag=True,
+ help="Write empty __init__.py files to the environment and " "version locations",
+)
+@click.pass_context
+def init(ctx: click.Context, directory, package):
+ """Creates a new migration repository."""
+ handler = _get_handler_context(ctx)
+ handler.alembic_init(directory, package)
diff --git a/ellar_sqlalchemy/cli/handlers.py b/ellar_sqlalchemy/cli/handlers.py
new file mode 100644
index 0000000..4a29f40
--- /dev/null
+++ b/ellar_sqlalchemy/cli/handlers.py
@@ -0,0 +1,264 @@
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import sys
+import typing as t
+from functools import wraps
+from pathlib import Path
+
+from alembic import command
+from alembic.config import Config as AlembicConfig
+from alembic.util.exc import CommandError
+from ellar.app import App
+
+from ellar_sqlalchemy.services import EllarSQLAlchemyService
+
+log = logging.getLogger(__name__)
+RevIdType = t.Union[str, t.List[str], t.Tuple[str, ...]]
+
+
+class Config(AlembicConfig):
+ def get_template_directory(self) -> str:
+ package_dir = os.path.abspath(Path(__file__).parent.parent)
+ return os.path.join(package_dir, "templates")
+
+
+def _catch_errors(f: t.Callable) -> t.Callable: # type:ignore[type-arg]
+ @wraps(f)
+ def wrapped(*args: t.Any, **kwargs: t.Any) -> None:
+ try:
+ f(*args, **kwargs)
+ except (CommandError, RuntimeError) as exc:
+ log.error("Error: " + str(exc))
+ sys.exit(1)
+
+ return wrapped
+
+
+class CLICommandHandlers:
+ def __init__(self, db_service: EllarSQLAlchemyService) -> None:
+ self.db_service = db_service
+
+ def get_config(
+ self,
+ directory: t.Optional[t.Any] = None,
+ x_arg: t.Optional[t.Any] = None,
+ opts: t.Optional[t.Any] = None,
+ ) -> Config:
+ directory = (
+ str(directory) if directory else self.db_service.migration_options.directory
+ )
+
+ config = Config(os.path.join(directory, "alembic.ini"))
+
+ config.set_main_option("script_location", directory)
+ config.set_main_option(
+ "sqlalchemy.url", str(self.db_service.engine.url).replace("%", "%%")
+ )
+
+ if config.cmd_opts is None:
+ config.cmd_opts = argparse.Namespace()
+
+ for opt in opts or []:
+ setattr(config.cmd_opts, opt, True)
+
+ if not hasattr(config.cmd_opts, "x"):
+ if x_arg is not None:
+ config.cmd_opts.x = []
+
+ if isinstance(x_arg, list) or isinstance(x_arg, tuple):
+ for x in x_arg:
+ config.cmd_opts.x.append(x)
+ else:
+ config.cmd_opts.x.append(x_arg)
+ else:
+ config.cmd_opts.x = None
+ return config
+
+ @_catch_errors
+ def alembic_init(self, directory: str | None = None, package: bool = False) -> None:
+ """Creates a new migration repository"""
+ if directory is None:
+ directory = self.db_service.migration_options.directory
+
+ config = Config()
+ config.set_main_option("script_location", directory)
+ config.config_file_name = os.path.join(directory, "alembic.ini")
+
+ command.init(config, directory, template="basic", package=package)
+
+ @_catch_errors
+ def revision(
+ self,
+ directory: str | None = None,
+ message: str | None = None,
+ autogenerate: bool = False,
+ sql: bool = False,
+ head: str = "head",
+ splice: bool = False,
+ branch_label: RevIdType | None = None,
+ version_path: str | None = None,
+ rev_id: str | None = None,
+ ) -> None:
+ """Create a new revision file."""
+ opts = ["autogenerate"] if autogenerate else None
+
+ config = self.get_config(directory, opts=opts)
+ command.revision(
+ config,
+ message,
+ autogenerate=autogenerate,
+ sql=sql,
+ head=head,
+ splice=splice,
+ branch_label=branch_label,
+ version_path=version_path,
+ rev_id=rev_id,
+ )
+
+ @_catch_errors
+ def migrate(
+ self,
+ directory: str | None = None,
+ message: str | None = None,
+ sql: bool = False,
+ head: str = "head",
+ splice: bool = False,
+ branch_label: RevIdType | None = None,
+ version_path: str | None = None,
+ rev_id: str | None = None,
+ x_arg: str | None = None,
+ ) -> None:
+ """Alias for 'revision --autogenerate'"""
+ config = self.get_config(
+ directory,
+ opts=["autogenerate"],
+ x_arg=x_arg,
+ )
+ command.revision(
+ config,
+ message,
+ autogenerate=True,
+ sql=sql,
+ head=head,
+ splice=splice,
+ branch_label=branch_label,
+ version_path=version_path,
+ rev_id=rev_id,
+ )
+
+ @_catch_errors
+ def edit(self, directory: str | None = None, revision: str = "current") -> None:
+ """Edit current revision."""
+ config = self.get_config(directory)
+ command.edit(config, revision)
+
+ @_catch_errors
+ def merge(
+ self,
+ directory: str | None = None,
+ revisions: RevIdType = "",
+ message: str | None = None,
+ branch_label: RevIdType | None = None,
+ rev_id: str | None = None,
+ ) -> None:
+ """Merge two revisions together. Creates a new migration file"""
+ config = self.get_config(directory)
+ command.merge(
+ config, revisions, message=message, branch_label=branch_label, rev_id=rev_id
+ )
+
+ @_catch_errors
+ def upgrade(
+ self,
+ directory: str | None = None,
+ revision: str = "head",
+ sql: bool = False,
+ tag: str | None = None,
+ x_arg: str | None = None,
+ ) -> None:
+ """Upgrade to a later version"""
+ config = self.get_config(directory, x_arg=x_arg)
+ command.upgrade(config, revision, sql=sql, tag=tag)
+
+ @_catch_errors
+ def downgrade(
+ self,
+ directory: str | None = None,
+ revision: str = "-1",
+ sql: bool = False,
+ tag: str | None = None,
+ x_arg: str | None = None,
+ ) -> None:
+ """Revert to a previous version"""
+ config = self.get_config(directory, x_arg=x_arg)
+ if sql and revision == "-1":
+ revision = "head:-1"
+ command.downgrade(config, revision, sql=sql, tag=tag)
+
+ @_catch_errors
+ def show(self, directory: str | None = None, revision: str = "head") -> None:
+ """Show the revision denoted by the given symbol."""
+ config = self.get_config(directory)
+ command.show(config, revision) # type:ignore[no-untyped-call]
+
+ @_catch_errors
+ def history(
+ self,
+ directory: str | None = None,
+ rev_range: t.Any = None,
+ verbose: bool = False,
+ indicate_current: bool = False,
+ ) -> None:
+ """List changeset scripts in chronological order."""
+ config = self.get_config(directory)
+ command.history(
+ config, rev_range, verbose=verbose, indicate_current=indicate_current
+ )
+
+ @_catch_errors
+ def heads(
+ self,
+ directory: str | None = None,
+ verbose: bool = False,
+ resolve_dependencies: bool = False,
+ ) -> None:
+ """Show current available heads in the script directory"""
+ config = self.get_config(directory)
+ command.heads( # type:ignore[no-untyped-call]
+ config, verbose=verbose, resolve_dependencies=resolve_dependencies
+ )
+
+ @_catch_errors
+ def branches(self, directory: str | None = None, verbose: bool = False) -> None:
+ """Show current branch points"""
+ config = self.get_config(directory)
+ command.branches(config, verbose=verbose) # type:ignore[no-untyped-call]
+
+ @_catch_errors
+ def current(self, directory: str | None = None, verbose: bool = False) -> None:
+ """Display the current revision for each database."""
+ config = self.get_config(directory)
+ command.current(config, verbose=verbose)
+
+ @_catch_errors
+ def stamp(
+ self,
+ app: App,
+ directory: str | None = None,
+ revision: str = "head",
+ sql: bool = False,
+ tag: t.Any = None,
+ ) -> None:
+ """'stamp' the revision table with the given revision; don't run any
+ migrations"""
+ config = self.get_config(app, directory)
+ command.stamp(config, revision, sql=sql, tag=tag)
+
+ @_catch_errors
+ def check(self, app: App, directory: str | None = None) -> None:
+ """Check if there are any new operations to migrate"""
+ config = self.get_config(app, directory)
+ command.check(config)
diff --git a/ellar_sqlalchemy/constant.py b/ellar_sqlalchemy/constant.py
new file mode 100644
index 0000000..1931674
--- /dev/null
+++ b/ellar_sqlalchemy/constant.py
@@ -0,0 +1,11 @@
+import sqlalchemy.orm as sa_orm
+
+DATABASE_BIND_KEY = "database_bind_key"
+DEFAULT_KEY = "default"
+DATABASE_KEY = "__database__"
+TABLE_KEY = "__table__"
+ABSTRACT_KEY = "__abstract__"
+
+
+class DeclarativeBasePlaceHolder(sa_orm.DeclarativeBase):
+ pass
diff --git a/ellar_sqlalchemy/exceptions.py b/ellar_sqlalchemy/exceptions.py
new file mode 100644
index 0000000..e69de29
diff --git a/ellar_sqlalchemy/migrations/__init__.py b/ellar_sqlalchemy/migrations/__init__.py
new file mode 100644
index 0000000..34dad68
--- /dev/null
+++ b/ellar_sqlalchemy/migrations/__init__.py
@@ -0,0 +1,9 @@
+from .base import AlembicEnvMigrationBase
+from .multiple import MultipleDatabaseAlembicEnvMigration
+from .single import SingleDatabaseAlembicEnvMigration
+
+__all__ = [
+ "SingleDatabaseAlembicEnvMigration",
+ "MultipleDatabaseAlembicEnvMigration",
+ "AlembicEnvMigrationBase",
+]
diff --git a/ellar_sqlalchemy/migrations/base.py b/ellar_sqlalchemy/migrations/base.py
new file mode 100644
index 0000000..ff1a959
--- /dev/null
+++ b/ellar_sqlalchemy/migrations/base.py
@@ -0,0 +1,51 @@
+import typing as t
+from abc import abstractmethod
+
+from alembic.runtime.environment import NameFilterType
+from sqlalchemy.sql.schema import SchemaItem
+
+from ellar_sqlalchemy.services import EllarSQLAlchemyService
+from ellar_sqlalchemy.types import RevisionArgs
+
+if t.TYPE_CHECKING:
+ from alembic.operations import MigrationScript
+ from alembic.runtime.environment import EnvironmentContext
+ from alembic.runtime.migration import MigrationContext
+
+
+class AlembicEnvMigrationBase:
+ def __init__(self, db_service: EllarSQLAlchemyService) -> None:
+ self.db_service = db_service
+ self.use_two_phase = db_service.migration_options.use_two_phase
+
+ def include_object(
+ self,
+ obj: SchemaItem,
+ name: t.Optional[str],
+ type_: NameFilterType,
+ reflected: bool,
+ compare_to: t.Optional[SchemaItem],
+ ) -> bool:
+ # If you want to ignore things like these, set the following as a class attribute
+ # __table_args__ = {"info": {"skip_autogen": True}}
+ if obj.info.get("skip_autogen", False):
+ return False
+
+ return True
+
+ @abstractmethod
+ def default_process_revision_directives(
+ self,
+ context: "MigrationContext",
+ revision: RevisionArgs,
+ directives: t.List["MigrationScript"],
+ ) -> t.Any:
+ pass
+
+ @abstractmethod
+ def run_migrations_offline(self, context: "EnvironmentContext") -> None:
+ pass
+
+ @abstractmethod
+ async def run_migrations_online(self, context: "EnvironmentContext") -> None:
+ pass
diff --git a/ellar_sqlalchemy/migrations/multiple.py b/ellar_sqlalchemy/migrations/multiple.py
new file mode 100644
index 0000000..f899f26
--- /dev/null
+++ b/ellar_sqlalchemy/migrations/multiple.py
@@ -0,0 +1,212 @@
+import logging
+import typing as t
+from dataclasses import dataclass
+
+import sqlalchemy as sa
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine
+
+from ellar_sqlalchemy.model.database_binds import get_database_bind
+from ellar_sqlalchemy.types import RevisionArgs
+
+from .base import AlembicEnvMigrationBase
+
+if t.TYPE_CHECKING:
+ from alembic.operations import MigrationScript
+ from alembic.runtime.environment import EnvironmentContext
+ from alembic.runtime.migration import MigrationContext
+
+
+logger = logging.getLogger("alembic.env")
+
+
+@dataclass
+class DatabaseInfo:
+ name: str
+ metadata: sa.MetaData
+ engine: t.Union[sa.Engine, AsyncEngine]
+ connection: t.Union[sa.Connection, AsyncConnection]
+ use_two_phase: bool = False
+
+ _transaction: t.Optional[t.Union[sa.TwoPhaseTransaction, sa.RootTransaction]] = None
+ _sync_connection: t.Optional[sa.Connection] = None
+
+ def sync_connection(self) -> sa.Connection:
+ if not self._sync_connection:
+ self._sync_connection = getattr(
+ self.connection, "sync_connection", self.connection
+ )
+ assert self._sync_connection is not None
+ return self._sync_connection
+
+ def get_transactions(self) -> t.Union[sa.TwoPhaseTransaction, sa.RootTransaction]:
+ if not self._transaction:
+ if self.use_two_phase:
+ self._transaction = self.sync_connection().begin_twophase()
+ else:
+ self._transaction = self.sync_connection().begin()
+ assert self._transaction is not None
+ return self._transaction
+
+
+class MultipleDatabaseAlembicEnvMigration(AlembicEnvMigrationBase):
+ """
+ Migration Class for Multiple Database Configuration
+ for both asynchronous and synchronous database engine dialect
+ """
+
+ def default_process_revision_directives(
+ self,
+ context: "MigrationContext",
+ revision: RevisionArgs,
+ directives: t.List["MigrationScript"],
+ ) -> None:
+ if getattr(context.config.cmd_opts, "autogenerate", False):
+ script = directives[0]
+
+ if len(script.upgrade_ops_list) == len(self.db_service.engines.keys()):
+ # wait till there is a full check of all databases before removing empty operations
+
+ for upgrade_ops in list(script.upgrade_ops_list):
+ if upgrade_ops.is_empty():
+ script.upgrade_ops_list.remove(upgrade_ops)
+
+ for downgrade_ops in list(script.downgrade_ops_list):
+ if downgrade_ops.is_empty():
+ script.downgrade_ops_list.remove(downgrade_ops)
+
+ if (
+ len(script.upgrade_ops_list) == 0
+ and len(script.downgrade_ops_list) == 0
+ ):
+ directives[:] = []
+ logger.info("No changes in schema detected.")
+
+ def run_migrations_offline(self, context: "EnvironmentContext") -> None:
+ """Run migrations in 'offline' mode.
+
+ This configures the context with just a URL
+ and not an Engine, though an Engine is acceptable
+ here as well. By skipping the Engine creation,
+ we don't even need a DBAPI to be available.
+
+ Calls to context.execute() here emit the given string to the
+ script output.
+
+ """
+ # for --sql use case, run migrations for each URL into
+ # individual files.
+
+ for key, engine in self.db_service.engines.items():
+ logger.info("Migrating database %s" % key)
+
+ url = str(engine.url).replace("%", "%%")
+ metadata = get_database_bind(key, certain=True)
+
+ file_ = "%s.sql" % key
+ logger.info("Writing output to %s" % file_)
+ with open(file_, "w") as buffer:
+ context.configure(
+ url=url,
+ output_buffer=buffer,
+ target_metadata=metadata,
+ literal_binds=True,
+ dialect_opts={"paramstyle": "named"},
+ # If you want to ignore things like these, set the following as a class attribute
+ # __table_args__ = {"info": {"skip_autogen": True}}
+ include_object=self.include_object,
+ # detecting type changes
+ # compare_type=True,
+ )
+ with context.begin_transaction():
+ context.run_migrations(engine_name=key)
+
+ def _migration_action(
+ self, _: t.Any, db_infos: t.List[DatabaseInfo], context: "EnvironmentContext"
+ ) -> None:
+ # this callback is used to prevent an auto-migration from being generated
+ # when there are no changes to the schema
+ # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html
+ conf_args = {
+ "process_revision_directives": self.default_process_revision_directives
+ }
+ # conf_args = current_app.extensions['migrate'].configure_args
+ # if conf_args.get("process_revision_directives") is None:
+ # conf_args["process_revision_directives"] = process_revision_directives
+
+ try:
+ for db_info in db_infos:
+ context.configure(
+ connection=db_info.sync_connection(),
+ upgrade_token="%s_upgrades" % db_info.name,
+ downgrade_token="%s_downgrades" % db_info.name,
+ target_metadata=db_info.metadata,
+ **conf_args,
+ )
+
+ context.run_migrations(engine_name=db_info.name)
+
+ if self.use_two_phase:
+ for db_info in db_infos:
+ db_info.get_transactions().prepare() # type:ignore[attr-defined]
+
+ for db_info in db_infos:
+ db_info.get_transactions().commit()
+
+ except Exception as ex:
+ for db_info in db_infos:
+ db_info.get_transactions().rollback()
+
+ logger.error(ex)
+ raise ex
+ finally:
+ for db_info in db_infos:
+ db_info.sync_connection().close()
+
+ async def _check_if_coroutine(self, func: t.Any) -> t.Any:
+ if isinstance(func, t.Coroutine):
+ return await func
+ return func
+
+ async def _compute_engine_info(self) -> t.List[DatabaseInfo]:
+ res = []
+
+ for key, engine in self.db_service.engines.items():
+ metadata = get_database_bind(key, certain=True)
+
+ if engine.dialect.is_async:
+ async_engine = AsyncEngine(engine)
+ connection = async_engine.connect()
+ connection = await connection.start()
+ engine = async_engine # type:ignore[assignment]
+ else:
+ connection = engine.connect() # type:ignore[assignment]
+
+ database_info = DatabaseInfo(
+ engine=engine,
+ metadata=metadata,
+ connection=connection,
+ name=key,
+ use_two_phase=self.use_two_phase,
+ )
+ database_info.get_transactions()
+ res.append(database_info)
+ return res
+
+ async def run_migrations_online(self, context: "EnvironmentContext") -> None:
+ # for the direct-to-DB use case, start a transaction on all
+ # engines, then run all migrations, then commit all transactions.
+
+ database_infos = await self._compute_engine_info()
+ async_db_info_filter = [
+ db_info for db_info in database_infos if db_info.engine.dialect.is_async
+ ]
+ try:
+ if len(async_db_info_filter) > 0:
+ await async_db_info_filter[0].connection.run_sync(
+ self._migration_action, database_infos, context
+ )
+ else:
+ self._migration_action(None, database_infos, context)
+ finally:
+ for database_info_ in database_infos:
+ await self._check_if_coroutine(database_info_.connection.close())
diff --git a/ellar_sqlalchemy/migrations/single.py b/ellar_sqlalchemy/migrations/single.py
new file mode 100644
index 0000000..2923c66
--- /dev/null
+++ b/ellar_sqlalchemy/migrations/single.py
@@ -0,0 +1,113 @@
+import functools
+import logging
+import typing as t
+
+import sqlalchemy as sa
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+from ellar_sqlalchemy.model.database_binds import get_database_bind
+from ellar_sqlalchemy.types import RevisionArgs
+
+from .base import AlembicEnvMigrationBase
+
+if t.TYPE_CHECKING:
+ from alembic.operations import MigrationScript
+ from alembic.runtime.environment import EnvironmentContext
+ from alembic.runtime.migration import MigrationContext
+
+
+logger = logging.getLogger("alembic.env")
+
+
+class SingleDatabaseAlembicEnvMigration(AlembicEnvMigrationBase):
+ """
+ Migration Class for a Single Database Configuration
+ for both asynchronous and synchronous database engine dialect
+ """
+
+ def default_process_revision_directives(
+ self,
+ context: "MigrationContext",
+ revision: RevisionArgs,
+ directives: t.List["MigrationScript"],
+ ) -> None:
+ if getattr(context.config.cmd_opts, "autogenerate", False):
+ script = directives[0]
+ if script.upgrade_ops.is_empty():
+ directives[:] = []
+ logger.info("No changes in schema detected.")
+
+ def run_migrations_offline(self, context: "EnvironmentContext") -> None:
+ """Run migrations in 'offline' mode.
+
+ This configures the context with just a URL
+ and not an Engine, though an Engine is acceptable
+ here as well. By skipping the Engine creation
+ we don't even need a DBAPI to be available.
+
+ Calls to context.execute() here emit the given string to the
+ script output.
+
+ """
+
+ key, engine = self.db_service.engines.popitem()
+ metadata = get_database_bind(key, certain=True)
+
+ context.configure(
+ url=str(engine.url).replace("%", "%%"),
+ target_metadata=metadata,
+ literal_binds=True,
+ dialect_opts={"paramstyle": "named"},
+ # If you want to ignore things like these, set the following as a class attribute
+ # __table_args__ = {"info": {"skip_autogen": True}}
+ include_object=self.include_object,
+ # detecting type changes
+ # compare_type=True,
+ )
+
+ with context.begin_transaction():
+ context.run_migrations()
+
+ def _migration_action(
+ self,
+ connection: sa.Connection,
+ metadata: sa.MetaData,
+ context: "EnvironmentContext",
+ ) -> None:
+ # this callback is used to prevent an auto-migration from being generated
+ # when there are no changes to the schema
+ # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html
+ conf_args = {
+ "process_revision_directives": self.default_process_revision_directives
+ }
+ # conf_args = current_app.extensions['migrate'].configure_args
+ # if conf_args.get("process_revision_directives") is None:
+ # conf_args["process_revision_directives"] = process_revision_directives
+
+ context.configure(connection=connection, target_metadata=metadata, **conf_args)
+
+ with context.begin_transaction():
+ context.run_migrations()
+
+ async def run_migrations_online(self, context: "EnvironmentContext") -> None:
+ """Run migrations in 'online' mode.
+
+ In this scenario we need to create an Engine
+ and associate a connection with the context.
+
+ """
+
+ key, engine = self.db_service.engines.popitem()
+ metadata = get_database_bind(key, certain=True)
+
+ migration_action_partial = functools.partial(
+ self._migration_action, metadata=metadata, context=context
+ )
+
+ if engine.dialect.is_async:
+ async_engine = AsyncEngine(engine)
+ async with async_engine.connect() as connection:
+ await connection.run_sync(migration_action_partial)
+ else:
+ with engine.connect() as connection:
+ migration_action_partial(connection)
diff --git a/ellar_sqlalchemy/model/__init__.py b/ellar_sqlalchemy/model/__init__.py
new file mode 100644
index 0000000..5225e55
--- /dev/null
+++ b/ellar_sqlalchemy/model/__init__.py
@@ -0,0 +1,10 @@
+from .base import Model
+from .typeDecorator import GUID, GenericIP
+from .utils import make_metadata
+
+__all__ = [
+ "Model",
+ "make_metadata",
+ "GUID",
+ "GenericIP",
+]
diff --git a/ellar_sqlalchemy/model/base.py b/ellar_sqlalchemy/model/base.py
new file mode 100644
index 0000000..b5bd940
--- /dev/null
+++ b/ellar_sqlalchemy/model/base.py
@@ -0,0 +1,121 @@
+import types
+import typing as t
+
+import sqlalchemy.orm as sa_orm
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import DeclarativeBase
+
+from ellar_sqlalchemy.constant import (
+ DATABASE_BIND_KEY,
+ DEFAULT_KEY,
+)
+
+from .database_binds import get_database_bind, has_database_bind, update_database_binds
+from .mixins import (
+ DatabaseBindKeyMixin,
+ ModelDataExportMixin,
+ ModelTrackMixin,
+ NameMetaMixin,
+)
+
+SQLAlchemyDefaultBase = None
+
+
+def _model_as_base(
+ name: str, bases: t.Tuple[t.Any, ...], namespace: t.Dict[str, t.Any]
+) -> t.Type["Model"]:
+ global SQLAlchemyDefaultBase
+
+ if SQLAlchemyDefaultBase is None:
+ declarative_bases = [
+ b
+ for b in bases
+ if issubclass(b, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta))
+ ]
+
+ def get_session(cls: t.Type[Model]) -> None:
+ raise Exception("EllarSQLAlchemyService is not ready")
+
+ namespace.update(
+ get_db_session=getattr(Model, "get_session", classmethod(get_session)),
+ skip_default_base_check=True,
+ )
+
+ model = types.new_class(
+ f"{name}",
+ (
+ DatabaseBindKeyMixin,
+ NameMetaMixin,
+ ModelTrackMixin,
+ ModelDataExportMixin,
+ Model,
+ *declarative_bases,
+ sa_orm.DeclarativeBase,
+ ),
+ {},
+ lambda ns: ns.update(namespace),
+ )
+ model = t.cast(t.Type[Model], model)
+ SQLAlchemyDefaultBase = model
+
+ if not has_database_bind(DEFAULT_KEY):
+ # Use the model's metadata as the default metadata.
+ model.metadata.info[DATABASE_BIND_KEY] = DEFAULT_KEY
+ update_database_binds(DEFAULT_KEY, model.metadata)
+ else:
+ # Use the passed in default metadata as the model's metadata.
+ model.metadata = get_database_bind(DEFAULT_KEY, certain=True)
+ return model
+ else:
+ return SQLAlchemyDefaultBase
+
+
+class ModelMeta(type(DeclarativeBase)): # type:ignore[misc]
+ def __new__(
+ mcs,
+ name: str,
+ bases: t.Tuple[t.Any, ...],
+ namespace: t.Dict[str, t.Any],
+ **kwargs: t.Any,
+ ) -> t.Type[t.Union["Model", t.Any]]:
+ if bases == () and name == "Model":
+ return type.__new__(mcs, name, tuple(bases), namespace, **kwargs)
+
+ if "as_base" in kwargs:
+ return _model_as_base(name, bases, namespace)
+
+ _bases = list(bases)
+
+ skip_default_base_check = False
+ if "skip_default_base_check" in namespace:
+ skip_default_base_check = namespace.pop("skip_default_base_check")
+
+ if not skip_default_base_check:
+ if SQLAlchemyDefaultBase is None:
+ raise Exception(
+ "EllarSQLAlchemy Default Declarative Base has not been configured."
+ "\nPlease call `configure_model_declarative_base` before ORM Model construction"
+ " or Use EllarSQLAlchemy Service"
+ )
+ elif SQLAlchemyDefaultBase and SQLAlchemyDefaultBase not in _bases:
+ _bases = [SQLAlchemyDefaultBase, *_bases]
+
+ return super().__new__(mcs, name, (*_bases,), namespace, **kwargs) # type:ignore[no-any-return]
+
+
+class Model(metaclass=ModelMeta):
+ __database__: str = "default"
+
+ if t.TYPE_CHECKING:
+
+ def __init__(self, **kwargs: t.Any) -> None:
+ ...
+
+ @classmethod
+ def get_db_session(
+ cls,
+ ) -> t.Union[sa_orm.Session, AsyncSession, t.Any]:
+ ...
+
+ def dict(self, exclude: t.Optional[t.Set[str]] = None) -> t.Dict[str, t.Any]:
+ ...
diff --git a/ellar_sqlalchemy/model/database_binds.py b/ellar_sqlalchemy/model/database_binds.py
new file mode 100644
index 0000000..a55c32a
--- /dev/null
+++ b/ellar_sqlalchemy/model/database_binds.py
@@ -0,0 +1,23 @@
+import typing as t
+
+import sqlalchemy as sa
+
+__model_database_metadata__: t.Dict[str, sa.MetaData] = {}
+
+
+def update_database_binds(key: str, value: sa.MetaData) -> None:
+ __model_database_metadata__[key] = value
+
+
+def get_database_binds() -> t.Dict[str, sa.MetaData]:
+ return __model_database_metadata__.copy()
+
+
+def get_database_bind(key: str, certain: bool = False) -> sa.MetaData:
+ if certain:
+ return __model_database_metadata__[key]
+ return __model_database_metadata__.get(key) # type:ignore[return-value]
+
+
+def has_database_bind(key: str) -> bool:
+ return key in __model_database_metadata__
diff --git a/ellar_sqlalchemy/model/mixins.py b/ellar_sqlalchemy/model/mixins.py
new file mode 100644
index 0000000..ded7d01
--- /dev/null
+++ b/ellar_sqlalchemy/model/mixins.py
@@ -0,0 +1,79 @@
+import typing as t
+
+import sqlalchemy as sa
+
+from ellar_sqlalchemy.constant import ABSTRACT_KEY, DATABASE_KEY, DEFAULT_KEY, TABLE_KEY
+
+from .utils import camel_to_snake_case, make_metadata, should_set_table_name
+
+if t.TYPE_CHECKING:
+ from .base import Model
+
+__ellar_sqlalchemy_models__: t.Dict[str, t.Type["Model"]] = {}
+
+
+def get_registered_models() -> t.Dict[str, t.Type["Model"]]:
+ return __ellar_sqlalchemy_models__.copy()
+
+
+class NameMetaMixin:
+ metadata: sa.MetaData
+ __tablename__: str
+ __table__: sa.Table
+
+ def __init_subclass__(cls, **kwargs: t.Dict[str, t.Any]) -> None:
+ if should_set_table_name(cls):
+ cls.__tablename__ = camel_to_snake_case(cls.__name__)
+
+ super().__init_subclass__(**kwargs)
+
+
+class DatabaseBindKeyMixin:
+ metadata: sa.MetaData
+ __dnd__ = "Ellar"
+
+ def __init_subclass__(cls, **kwargs: t.Dict[str, t.Any]) -> None:
+ if not ("metadata" in cls.__dict__ or TABLE_KEY in cls.__dict__) and hasattr(
+ cls, DATABASE_KEY
+ ):
+ database_bind_key = getattr(cls, DATABASE_KEY, DEFAULT_KEY)
+ parent_metadata = getattr(cls, "metadata", None)
+ metadata = make_metadata(database_bind_key)
+
+ if metadata is not parent_metadata:
+ cls.metadata = metadata
+
+ super().__init_subclass__(**kwargs)
+
+
+class ModelTrackMixin:
+ metadata: sa.MetaData
+
+ def __init_subclass__(cls, **kwargs: t.Dict[str, t.Any]) -> None:
+ super().__init_subclass__(**kwargs)
+
+ if TABLE_KEY in cls.__dict__ and ABSTRACT_KEY not in cls.__dict__:
+ __ellar_sqlalchemy_models__[str(cls)] = cls # type:ignore[assignment]
+
+
+class ModelDataExportMixin:
+ def __repr__(self) -> str:
+ columns = ", ".join(
+ [
+ f"{k}={repr(v)}"
+ for k, v in self.__dict__.items()
+ if not k.startswith("_")
+ ]
+ )
+ return f"<{self.__class__.__name__}({columns})>"
+
+ def dict(self, exclude: t.Optional[t.Set[str]] = None) -> t.Dict[str, t.Any]:
+ # TODO: implement advance exclude and include that goes deep into relationships too
+ _exclude: t.Set[str] = set() if not exclude else exclude
+
+ tuple_generator = (
+ (k, v)
+ for k, v in self.__dict__.items()
+ if k not in _exclude and not k.startswith("_sa")
+ )
+ return dict(tuple_generator)
diff --git a/ellar_sqlalchemy/model/typeDecorator/__init__.py b/ellar_sqlalchemy/model/typeDecorator/__init__.py
new file mode 100644
index 0000000..daddaa9
--- /dev/null
+++ b/ellar_sqlalchemy/model/typeDecorator/__init__.py
@@ -0,0 +1,13 @@
+# from .file import FileField
+from .guid import GUID
+from .ipaddress import GenericIP
+
+# from .image import CroppingDetails, ImageFileField
+
+__all__ = [
+ "GUID",
+ "GenericIP",
+ # "CroppingDetails",
+ # "FileField",
+ # "ImageFileField",
+]
diff --git a/ellar_sqlalchemy/model/typeDecorator/exceptions.py.ellar b/ellar_sqlalchemy/model/typeDecorator/exceptions.py.ellar
new file mode 100644
index 0000000..779b32a
--- /dev/null
+++ b/ellar_sqlalchemy/model/typeDecorator/exceptions.py.ellar
@@ -0,0 +1,25 @@
+class ContentTypeValidationError(Exception):
+ def __init__(self, content_type=None, valid_content_types=None):
+
+ if content_type is None:
+ message = "Content type is not provided. "
+ else:
+ message = "Content type is not supported %s. " % content_type
+
+ if valid_content_types:
+ message += "Valid options are: %s" % ", ".join(valid_content_types)
+
+ super().__init__(message)
+
+
+class InvalidFileError(Exception):
+ pass
+
+
+class InvalidImageOperationError(Exception):
+ pass
+
+
+class MaximumAllowedFileLengthError(Exception):
+ def __init__(self, max_length: int):
+ super().__init__("Cannot store files larger than: %d bytes" % max_length)
diff --git a/ellar_sqlalchemy/model/typeDecorator/file.py.ellar b/ellar_sqlalchemy/model/typeDecorator/file.py.ellar
new file mode 100644
index 0000000..4a8a83a
--- /dev/null
+++ b/ellar_sqlalchemy/model/typeDecorator/file.py.ellar
@@ -0,0 +1,208 @@
+import json
+import time
+import typing as t
+import uuid
+
+from sqlalchemy import JSON, String, TypeDecorator
+from starlette.datastructures import UploadFile
+
+from fullview_trader.core.storage import BaseStorage
+from fullview_trader.core.storage.utils import get_length, get_valid_filename
+
+from .exceptions import (
+ ContentTypeValidationError,
+ InvalidFileError,
+ MaximumAllowedFileLengthError,
+)
+from .mimetypes import guess_extension, magic_mime_from_buffer
+
+T = t.TypeVar("T", bound="FileObject")
+
+
+class FileObject:
+ def __init__(
+ self,
+ *,
+ storage: BaseStorage,
+ original_filename: str,
+ uploaded_on: int,
+ content_type: str,
+ saved_filename: str,
+ extension: str,
+ file_size: int,
+ ) -> None:
+ self._storage = storage
+ self.original_filename = original_filename
+ self.uploaded_on = uploaded_on
+ self.content_type = content_type
+ self.filename = saved_filename
+ self.extension = extension
+ self.file_size = file_size
+
+ def locate(self) -> str:
+ return self._storage.locate(self.filename)
+
+ def open(self) -> t.IO:
+ return self._storage.open(self.filename)
+
+ def to_dict(self) -> dict:
+ return {
+ "original_filename": self.original_filename,
+ "uploaded_on": self.uploaded_on,
+ "content_type": self.content_type,
+ "extension": self.extension,
+ "file_size": self.file_size,
+ "saved_filename": self.filename,
+ "service_name": self._storage.service_name(),
+ }
+
+ def __str__(self) -> str:
+ return f"filename={self.filename}, content_type={self.content_type}, file_size={self.file_size}"
+
+ def __repr__(self) -> str:
+ return str(self)
+
+
+class FileFieldBase(t.Generic[T]):
+ FileObject: t.Type[T] = FileObject
+
+ def load_dialect_impl(self, dialect):
+ if dialect.name == "sqlite":
+ return dialect.type_descriptor(String())
+ else:
+ return dialect.type_descriptor(JSON())
+
+ def __init__(
+ self,
+ *args: t.Any,
+ storage: BaseStorage = None,
+ allowed_content_types: t.List[str] = None,
+ max_size: t.Optional[int] = None,
+ **kwargs: t.Any,
+ ):
+ if allowed_content_types is None:
+ allowed_content_types = []
+ super().__init__(*args, **kwargs)
+
+ self._storage = storage
+ self._allowed_content_types = allowed_content_types
+ self._max_size = max_size
+
+ def validate(self, file: T) -> None:
+ if self._allowed_content_types and file.content_type not in self._allowed_content_types:
+ raise ContentTypeValidationError(file.content_type, self._allowed_content_types)
+ if self._max_size and file.file_size > self._max_size:
+ raise MaximumAllowedFileLengthError(self._max_size)
+
+ self._storage.validate_file_name(file.filename)
+
+ def load_from_str(self, data: str) -> T:
+ data_dict = t.cast(t.Dict, json.loads(data))
+ return self.load(data_dict)
+
+ def load(self, data: dict) -> T:
+ if "service_name" in data:
+ data.pop("service_name")
+ return self.FileObject(storage=self._storage, **data)
+
+ def _guess_content_type(self, file: t.IO) -> str:
+ content = file.read(1024)
+
+ if isinstance(content, str):
+ content = str.encode(content)
+
+ file.seek(0)
+
+ return magic_mime_from_buffer(content)
+
+ def get_extra_file_initialization_context(self, file: UploadFile) -> dict:
+ return {}
+
+ def convert_to_file_object(self, file: UploadFile) -> T:
+ unique_name = str(uuid.uuid4())
+
+ original_filename = file.filename
+
+ # use python magic to get the content type
+ content_type = self._guess_content_type(file.file)
+ extension = guess_extension(content_type)
+
+ file_size = get_length(file.file)
+ saved_filename = f"{original_filename[:-len(extension)]}_{unique_name[:-8]}{extension}"
+ saved_filename = get_valid_filename(saved_filename)
+
+ init_kwargs = self.get_extra_file_initialization_context(file)
+ init_kwargs.update(
+ storage=self._storage,
+ original_filename=original_filename,
+ uploaded_on=int(time.time()),
+ content_type=content_type,
+ extension=extension,
+ file_size=file_size,
+ saved_filename=saved_filename,
+ )
+ return self.FileObject(**init_kwargs)
+
+ def process_bind_param_action(
+ self, value: t.Any, dialect: t.Any
+ ) -> t.Optional[t.Union[str, dict]]:
+ if value is None:
+ return value
+
+ if isinstance(value, UploadFile):
+ value.file.seek(0) # make sure we are always at the beginning
+ file_obj = self.convert_to_file_object(value)
+ self.validate(file_obj)
+
+ self._storage.put(file_obj.filename, value.file)
+ value = file_obj
+
+ if isinstance(value, FileObject):
+ if dialect.name == "sqlite":
+ return json.dumps(value.to_dict())
+ return value.to_dict()
+
+ raise InvalidFileError()
+
+ def process_result_value_action(
+ self, value: t.Any, dialect: t.Any
+ ) -> t.Optional[t.Union[str, dict]]:
+ if value is None:
+ return value
+ else:
+ if isinstance(value, str):
+ value = self.load_from_str(value)
+ elif isinstance(value, dict):
+ value = self.load(value)
+ return value
+
+
+class FileField(FileFieldBase[FileObject], TypeDecorator):
+ """
+ Provide SqlAlchemy TypeDecorator for saving files
+ ## Basic Usage
+
+ fs = FileSystemStorage('path/to/save/files')
+
+ class MyTable(Base):
+ image: FileField.FileObject = sa.Column(
+ ImageFileField(storage=fs, max_size=10*MB, allowed_content_type=["application/pdf"]),
+ nullable=True
+ )
+
+ def route(file: File[UploadFile]):
+ session = SessionLocal()
+ my_table_model = MyTable(image=file)
+ session.add(my_table_model)
+ session.commit()
+ return my_table_model.image.to_dict()
+
+ """
+
+ impl = JSON
+
+ def process_bind_param(self, value, dialect):
+ return self.process_bind_param_action(value, dialect)
+
+ def process_result_value(self, value, dialect):
+ return self.process_result_value_action(value, dialect)
diff --git a/ellar_sqlalchemy/model/typeDecorator/guid.py b/ellar_sqlalchemy/model/typeDecorator/guid.py
new file mode 100644
index 0000000..e745753
--- /dev/null
+++ b/ellar_sqlalchemy/model/typeDecorator/guid.py
@@ -0,0 +1,47 @@
+import typing as t
+import uuid
+
+import sqlalchemy as sa
+from sqlalchemy.dialects.postgresql import UUID
+from sqlalchemy.types import CHAR
+
+
+class GUID(sa.TypeDecorator): # type: ignore[type-arg]
+ """Platform-independent GUID type.
+
+ Uses PostgreSQL's UUID type, otherwise uses
+ CHAR(32), storing as stringified hex values.
+
+ """
+
+ impl = CHAR
+
+ def load_dialect_impl(self, dialect: sa.Dialect) -> t.Any:
+ if dialect.name == "postgresql":
+ return dialect.type_descriptor(UUID())
+ else:
+ return dialect.type_descriptor(CHAR(32))
+
+ def process_bind_param(
+ self, value: t.Optional[t.Any], dialect: sa.Dialect
+ ) -> t.Any:
+ if value is None:
+ return value
+ elif dialect.name == "postgresql":
+ return str(value)
+ else:
+ if not isinstance(value, uuid.UUID):
+ return "%.32x" % uuid.UUID(value).int
+ else:
+ # hexstring
+ return "%.32x" % value.int
+
+ def process_result_value(
+ self, value: t.Optional[t.Any], dialect: sa.Dialect
+ ) -> t.Any:
+ if value is None:
+ return value
+ else:
+ if not isinstance(value, uuid.UUID):
+ value = uuid.UUID(value)
+ return value
diff --git a/ellar_sqlalchemy/model/typeDecorator/image.py.ellar b/ellar_sqlalchemy/model/typeDecorator/image.py.ellar
new file mode 100644
index 0000000..b8b298c
--- /dev/null
+++ b/ellar_sqlalchemy/model/typeDecorator/image.py.ellar
@@ -0,0 +1,151 @@
+import typing as t
+from dataclasses import dataclass
+from io import SEEK_END, BytesIO
+
+from sqlalchemy import JSON, TypeDecorator
+from starlette.datastructures import UploadFile
+
+from .exceptions import InvalidImageOperationError
+
+try:
+ from PIL import Image
+except ImportError as im_ex: # pragma: no cover
+ raise Exception("Pillow package is required. Use `pip install Pillow`.") from im_ex
+
+from fullview_trader.core.storage import BaseStorage
+
+from .file import FileFieldBase, FileObject
+
+
+@dataclass
+class CroppingDetails:
+ x: int
+ y: int
+ height: int
+ width: int
+
+
+class ImageFileObject(FileObject):
+ def __init__(self, *, height: float, width: float, **kwargs: t.Any) -> None:
+ super().__init__(**kwargs)
+ self.height = height
+ self.width = width
+
+ def to_dict(self) -> dict:
+ data = super().to_dict()
+ data.update(height=self.height, width=self.width)
+ return data
+
+
+class ImageFileField(FileFieldBase[ImageFileObject], TypeDecorator):
+ """
+ Provide SqlAlchemy TypeDecorator for Image files
+ ## Basic Usage
+
+ class MyTable(Base):
+ image:
+ ImageFileField.FileObject = sa.Column(ImageFileField(storage=FileSystemStorage('path/to/save/files',
+ max_size=10*MB), nullable=True)
+
+ def route(file: File[UploadFile]):
+ session = SessionLocal()
+ my_table_model = MyTable(image=file)
+ session.add(my_table_model)
+ session.commit()
+ return my_table_model.image.to_dict()
+
+ ## Cropping
+ Image file also provides cropping capabilities which can be defined in the column or when saving the image data.
+
+ fs = FileSystemStorage('path/to/save/files')
+ class MyTable(Base):
+ image = sa.Column(ImageFileField(storage=fs, crop=CroppingDetails(x=100, y=200, height=400, width=400)), nullable=True)
+
+ OR
+ def route(file: File[UploadFile]):
+ session = SessionLocal()
+ my_table_model = MyTable(
+ image=(file, CroppingDetails(x=100, y=200, height=400, width=400)),
+ )
+
+ """
+
+ impl = JSON
+ FileObject = ImageFileObject
+
+ def __init__(
+ self,
+ *args: t.Any,
+ storage: BaseStorage,
+ max_size: t.Optional[int] = None,
+ crop: t.Optional[CroppingDetails] = None,
+ **kwargs: t.Any
+ ):
+ kwargs.setdefault("allowed_content_types", ["image/jpeg", "image/png"])
+ super().__init__(*args, storage=storage, max_size=max_size, **kwargs)
+ self.crop = crop
+
+ def process_bind_param(self, value, dialect):
+ return self.process_bind_param_action(value, dialect)
+
+ def process_result_value(self, value, dialect):
+ return self.process_result_value_action(value, dialect)
+
+ def get_extra_file_initialization_context(self, file: UploadFile) -> dict:
+ with Image.open(file.file) as image:
+ width, height = image.size
+ return {"width": width, "height": height}
+
+ def crop_image_with_box_sizing(
+ self, file: UploadFile, crop: t.Optional[CroppingDetails] = None
+ ) -> UploadFile:
+ crop_info = crop or self.crop
+ img = Image.open(file.file)
+ (height, width, x, y,) = (
+ crop_info.height,
+ crop_info.width,
+ crop_info.x,
+ crop_info.y,
+ )
+ left = x
+ top = y
+ right = x + width
+ bottom = y + height
+
+ crop_box = (left, top, right, bottom)
+
+ img_res = img.crop(box=crop_box)
+ temp_thumb = BytesIO()
+ img_res.save(temp_thumb, img.format)
+ # Go to the end of the stream.
+ temp_thumb.seek(0, SEEK_END)
+
+ # Get the current position, which is now at the end.
+ # We can use this as the size.
+ size = temp_thumb.tell()
+ temp_thumb.seek(0)
+
+ content = UploadFile(
+ file=temp_thumb, filename=file.filename, size=size, headers=file.headers
+ )
+ return content
+
+ def process_bind_param_action(
+ self, value: t.Any, dialect: t.Any
+ ) -> t.Optional[t.Union[str, dict]]:
+ if isinstance(value, tuple):
+ file, crop_data = value
+ if not isinstance(file, UploadFile) or not isinstance(crop_data, CroppingDetails):
+ raise InvalidImageOperationError(
+ "Invalid data was provided for ImageFileField. "
+ "Accept values: UploadFile or (UploadFile, CroppingDetails)"
+ )
+ new_file = self.crop_image_with_box_sizing(file=file, crop=crop_data)
+ return super().process_bind_param_action(new_file, dialect)
+
+ if isinstance(value, UploadFile):
+ if self.crop:
+ return super().process_bind_param_action(
+ self.crop_image_with_box_sizing(value), dialect
+ )
+ return super().process_bind_param_action(value, dialect)
diff --git a/ellar_sqlalchemy/model/typeDecorator/ipaddress.py b/ellar_sqlalchemy/model/typeDecorator/ipaddress.py
new file mode 100644
index 0000000..8c29f91
--- /dev/null
+++ b/ellar_sqlalchemy/model/typeDecorator/ipaddress.py
@@ -0,0 +1,39 @@
+import ipaddress
+import typing as t
+
+import sqlalchemy as sa
+import sqlalchemy.dialects as sa_dialects
+
+
+class GenericIP(sa.TypeDecorator): # type:ignore[type-arg]
+ """
+ Platform-independent IP Address type.
+
+ Uses PostgreSQL's INET type, otherwise uses
+ CHAR(45), storing as stringified values.
+ """
+
+ impl = sa.CHAR
+ cache_ok = True
+
+ def load_dialect_impl(self, dialect: sa.Dialect) -> t.Any:
+ if dialect.name == "postgresql":
+ return dialect.type_descriptor(sa_dialects.postgresql.INET()) # type:ignore[attr-defined]
+ else:
+ return dialect.type_descriptor(sa.CHAR(45))
+
+ def process_bind_param(
+ self, value: t.Optional[t.Any], dialect: sa.Dialect
+ ) -> t.Any:
+ if value is not None:
+ return str(value)
+
+ def process_result_value(
+ self, value: t.Optional[t.Any], dialect: sa.Dialect
+ ) -> t.Any:
+ if value is None:
+ return value
+
+ if not isinstance(value, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
+ value = ipaddress.ip_address(value)
+ return value
diff --git a/ellar_sqlalchemy/model/typeDecorator/mimetypes.py.ellar b/ellar_sqlalchemy/model/typeDecorator/mimetypes.py.ellar
new file mode 100644
index 0000000..f287a76
--- /dev/null
+++ b/ellar_sqlalchemy/model/typeDecorator/mimetypes.py.ellar
@@ -0,0 +1,21 @@
+import mimetypes as mdb
+import typing
+
+import magic
+
+
+def magic_mime_from_buffer(buffer: bytes) -> str:
+ return magic.from_buffer(buffer, mime=True)
+
+
+def guess_extension(mimetype: str) -> typing.Optional[str]:
+ """
+ Due to the python bugs 'image/jpeg' overridden:
+ - https://bugs.python.org/issue4963
+ - https://bugs.python.org/issue1043134
+ - https://bugs.python.org/issue6626#msg91205
+ """
+
+ if mimetype == "image/jpeg":
+ return ".jpeg"
+ return mdb.guess_extension(mimetype)
diff --git a/ellar_sqlalchemy/model/utils.py b/ellar_sqlalchemy/model/utils.py
new file mode 100644
index 0000000..5e5a796
--- /dev/null
+++ b/ellar_sqlalchemy/model/utils.py
@@ -0,0 +1,65 @@
+import re
+
+import sqlalchemy as sa
+import sqlalchemy.orm as sa_orm
+
+from ellar_sqlalchemy.constant import DATABASE_BIND_KEY, DEFAULT_KEY
+
+from .database_binds import get_database_bind, has_database_bind, update_database_binds
+
+
+def make_metadata(database_key: str) -> sa.MetaData:
+ if has_database_bind(database_key):
+ return get_database_bind(database_key, certain=True)
+
+ if database_key is not None:
+ # Copy the naming convention from the default metadata.
+ naming_convention = make_metadata(DEFAULT_KEY).naming_convention
+ else:
+ naming_convention = None
+
+ # Set the bind key in info to be used by session.get_bind.
+ metadata = sa.MetaData(
+ naming_convention=naming_convention, info={DATABASE_BIND_KEY: database_key}
+ )
+ update_database_binds(database_key, metadata)
+ return metadata
+
+
+def camel_to_snake_case(name: str) -> str:
+ """Convert a ``CamelCase`` name to ``snake_case``."""
+ name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name)
+ return name.lower().lstrip("_")
+
+
+def should_set_table_name(cls: type) -> bool:
+ if (
+ cls.__dict__.get("__abstract__", False)
+ or (
+ not issubclass(cls, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta))
+ and not any(isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:])
+ )
+ or any(
+ (b is sa_orm.DeclarativeBase or b is sa_orm.DeclarativeBaseNoMeta)
+ for b in cls.__bases__
+ )
+ ):
+ return False
+
+ for base in cls.__mro__:
+ if "__tablename__" not in base.__dict__:
+ continue
+
+ if isinstance(base.__dict__["__tablename__"], sa_orm.declared_attr):
+ return False
+
+ return not (
+ base is cls
+ or base.__dict__.get("__abstract__", False)
+ or not (
+ isinstance(base, sa_orm.decl_api.DeclarativeAttributeIntercept)
+ or issubclass(base, sa_orm.DeclarativeBaseNoMeta)
+ )
+ )
+
+ return True
diff --git a/ellar_sqlalchemy/module.py b/ellar_sqlalchemy/module.py
new file mode 100644
index 0000000..20f69d5
--- /dev/null
+++ b/ellar_sqlalchemy/module.py
@@ -0,0 +1,146 @@
+import functools
+import typing as t
+
+import sqlalchemy as sa
+from ellar.app import current_injector
+from ellar.common import IApplicationShutdown, IModuleSetup, Module
+from ellar.common.utils.importer import get_main_directory_by_stack
+from ellar.core import Config, DynamicModule, ModuleBase, ModuleSetup
+from ellar.di import ProviderConfig, request_scope
+from sqlalchemy.ext.asyncio import (
+ AsyncEngine,
+ AsyncSession,
+)
+from sqlalchemy.orm import Session
+
+from ellar_sqlalchemy.services import EllarSQLAlchemyService
+
+from .cli import DBCommands
+from .schemas import MigrationOption, SQLAlchemyConfig
+
+
+@Module(commands=[DBCommands])
+class EllarSQLAlchemyModule(ModuleBase, IModuleSetup, IApplicationShutdown):
+ async def on_shutdown(self) -> None:
+ db_service = current_injector.get(EllarSQLAlchemyService)
+ db_service.session_factory.remove()
+
+ @classmethod
+ def setup(
+ cls,
+ *,
+ databases: t.Union[str, t.Dict[str, t.Any]],
+ migration_options: t.Union[t.Dict[str, t.Any], MigrationOption],
+ session_options: t.Optional[t.Dict[str, t.Any]] = None,
+ engine_options: t.Optional[t.Dict[str, t.Any]] = None,
+ models: t.Optional[t.List[str]] = None,
+ echo: bool = False,
+ ) -> "DynamicModule":
+ """
+ Configures EllarSQLAlchemyModule and setup required providers.
+ """
+ root_path = get_main_directory_by_stack("__main__", stack_level=2)
+ if isinstance(migration_options, dict):
+ migration_options.update(
+ directory=get_main_directory_by_stack(
+ migration_options.get("directory", "__main__/migrations"),
+ stack_level=2,
+ from_dir=root_path,
+ )
+ )
+ if isinstance(migration_options, MigrationOption):
+ migration_options.directory = get_main_directory_by_stack(
+ migration_options.directory, stack_level=2, from_dir=root_path
+ )
+ migration_options = migration_options.dict()
+
+ schema = SQLAlchemyConfig.model_validate(
+ {
+ "databases": databases,
+ "engine_options": engine_options,
+ "echo": echo,
+ "models": models,
+ "migration_options": migration_options,
+ "root_path": root_path,
+ },
+ from_attributes=True,
+ )
+ return cls.__setup_module(schema)
+
+ @classmethod
+ def __setup_module(cls, sql_alchemy_config: SQLAlchemyConfig) -> DynamicModule:
+ db_service = EllarSQLAlchemyService(
+ databases=sql_alchemy_config.databases,
+ common_engine_options=sql_alchemy_config.engine_options,
+ common_session_options=sql_alchemy_config.session_options,
+ echo=sql_alchemy_config.echo,
+ models=sql_alchemy_config.models,
+ root_path=sql_alchemy_config.root_path,
+ migration_options=sql_alchemy_config.migration_options,
+ )
+ providers: t.List[t.Any] = []
+
+ if db_service._async_session_type:
+ providers.append(ProviderConfig(AsyncEngine, use_value=db_service.engine))
+ providers.append(
+ ProviderConfig(
+ AsyncSession,
+ use_value=lambda: db_service.session_factory(),
+ scope=request_scope,
+ )
+ )
+ else:
+ providers.append(ProviderConfig(sa.Engine, use_value=db_service.engine))
+ providers.append(
+ ProviderConfig(
+ Session,
+ use_value=lambda: db_service.session_factory(),
+ scope=request_scope,
+ )
+ )
+
+ providers.append(ProviderConfig(EllarSQLAlchemyService, use_value=db_service))
+ return DynamicModule(
+ cls,
+ providers=providers,
+ )
+
+ @classmethod
+ def register_setup(cls, **override_config: t.Any) -> ModuleSetup:
+ """
+ Register Module to be configured through `SQLALCHEMY_CONFIG` variable in Application Config
+ """
+ root_path = get_main_directory_by_stack("__main__", stack_level=2)
+ return ModuleSetup(
+ cls,
+ inject=[Config],
+ factory=functools.partial(
+ cls.__register_setup_factory,
+ root_path=root_path,
+ override_config=override_config,
+ ),
+ )
+
+ @staticmethod
+ def __register_setup_factory(
+ module: t.Type["EllarSQLAlchemyModule"],
+ config: Config,
+ root_path: str,
+ override_config: t.Dict[str, t.Any],
+ ) -> DynamicModule:
+ if config.get("SQLALCHEMY_CONFIG") and isinstance(
+ config.SQLALCHEMY_CONFIG, dict
+ ):
+ defined_config = dict(config.SQLALCHEMY_CONFIG, root_path=root_path)
+ defined_config.update(override_config)
+
+ schema = SQLAlchemyConfig.model_validate(
+ defined_config, from_attributes=True
+ )
+
+ schema.migration_options.directory = get_main_directory_by_stack(
+ schema.migration_options.directory, stack_level=2, from_dir=root_path
+ )
+
+ return module.__setup_module(schema)
+ raise RuntimeError("Could not find `SQLALCHEMY_CONFIG` in application config.")
diff --git a/ellar_sqlalchemy/py.typed b/ellar_sqlalchemy/py.typed
new file mode 100644
index 0000000..e69de29
diff --git a/ellar_sqlalchemy/schemas.py b/ellar_sqlalchemy/schemas.py
new file mode 100644
index 0000000..b6752da
--- /dev/null
+++ b/ellar_sqlalchemy/schemas.py
@@ -0,0 +1,36 @@
+import typing as t
+from dataclasses import asdict, dataclass, field
+
+import ellar.common as ecm
+
+from ellar_sqlalchemy.types import RevisionDirectiveCallable
+
+
+@dataclass
+class MigrationOption:
+ directory: str
+ revision_directives_callbacks: t.List[RevisionDirectiveCallable] = field(
+ default_factory=lambda: []
+ )
+ use_two_phase: bool = False
+
+ def dict(self) -> t.Dict[str, t.Any]:
+ return asdict(self)
+
+
+class SQLAlchemyConfig(ecm.Serializer):
+ model_config = {"arbitrary_types_allowed": True}
+
+ databases: t.Union[str, t.Dict[str, t.Any]]
+ migration_options: MigrationOption
+ root_path: str
+
+ session_options: t.Dict[str, t.Any] = {
+ "autoflush": False,
+ "future": True,
+ "expire_on_commit": False,
+ }
+ echo: bool = False
+ engine_options: t.Optional[t.Dict[str, t.Any]] = None
+
+ models: t.Optional[t.List[str]] = None
diff --git a/ellar_sqlalchemy/services/__init__.py b/ellar_sqlalchemy/services/__init__.py
new file mode 100644
index 0000000..964e556
--- /dev/null
+++ b/ellar_sqlalchemy/services/__init__.py
@@ -0,0 +1,3 @@
+from .base import EllarSQLAlchemyService
+
+__all__ = ["EllarSQLAlchemyService"]
diff --git a/ellar_sqlalchemy/services/base.py b/ellar_sqlalchemy/services/base.py
new file mode 100644
index 0000000..dbfffa3
--- /dev/null
+++ b/ellar_sqlalchemy/services/base.py
@@ -0,0 +1,356 @@
+import os
+import typing as t
+from threading import get_ident
+from weakref import WeakKeyDictionary
+
+import sqlalchemy as sa
+import sqlalchemy.exc as sa_exc
+import sqlalchemy.orm as sa_orm
+from ellar.common.utils.importer import (
+ get_main_directory_by_stack,
+ import_from_string,
+ module_import,
+)
+from ellar.events import app_context_teardown_events
+from sqlalchemy.ext.asyncio import (
+ AsyncSession,
+ async_scoped_session,
+ async_sessionmaker,
+)
+
+from ellar_sqlalchemy.constant import (
+ DEFAULT_KEY,
+ DeclarativeBasePlaceHolder,
+)
+from ellar_sqlalchemy.model import (
+ make_metadata,
+)
+from ellar_sqlalchemy.model.base import Model
+from ellar_sqlalchemy.model.database_binds import get_database_bind, get_database_binds
+from ellar_sqlalchemy.schemas import MigrationOption
+from ellar_sqlalchemy.session import ModelSession
+
+from .metadata_engine import MetaDataEngine
+
+
+def _configure_model(
+ self: "EllarSQLAlchemyService",
+ models: t.Optional[t.List[str]] = None,
+) -> None:
+ for model in models or []:
+ module_import(model)
+
+ def model_get_session(
+ cls: t.Type[Model],
+ ) -> t.Union[sa_orm.Session, AsyncSession, t.Any]:
+ return self.session_factory()
+
+ sql_alchemy_declarative_base = import_from_string(
+ "ellar_sqlalchemy.model.base:SQLAlchemyDefaultBase"
+ )
+ base = (
+ Model if sql_alchemy_declarative_base is None else sql_alchemy_declarative_base
+ )
+
+ get_db_session = getattr(base, "get_db_session", DeclarativeBasePlaceHolder)
+ get_db_session_name = get_db_session.__name__ if get_db_session else ""
+
+ if get_db_session_name != "model_get_session":
+ base.get_db_session = classmethod(model_get_session)
+
+
+class EllarSQLAlchemyService:
+ def __init__(
+ self,
+ databases: t.Union[str, t.Dict[str, t.Any]],
+ *,
+ common_session_options: t.Optional[t.Dict[str, t.Any]] = None,
+ common_engine_options: t.Optional[t.Dict[str, t.Any]] = None,
+ models: t.Optional[t.List[str]] = None,
+ echo: bool = False,
+ root_path: t.Optional[str] = None,
+ migration_options: t.Optional[MigrationOption] = None,
+ ) -> None:
+ self._engines: WeakKeyDictionary[
+ "EllarSQLAlchemyService",
+ t.Dict[str, sa.engine.Engine],
+ ] = WeakKeyDictionary()
+
+ self._engines.setdefault(self, {})
+ self._session_options = common_session_options or {}
+
+ self._common_engine_options = common_engine_options or {}
+ self._execution_path = get_main_directory_by_stack(root_path, 2) # type:ignore[arg-type]
+
+ self.migration_options = migration_options or MigrationOption(
+ directory=get_main_directory_by_stack(
+ self._execution_path or "__main__/migrations", 2
+ )
+ )
+ self._async_session_type: bool = False
+
+ self._setup(databases, models=models, echo=echo)
+ self.session_factory = self.get_scoped_session()
+ app_context_teardown_events.connect(self._on_application_tear_down)
+
+ async def _on_application_tear_down(self) -> None:
+ res = self.session_factory.remove()
+ if isinstance(res, t.Coroutine):
+ await res
+
+ @property
+ def engines(self) -> t.Dict[str, sa.Engine]:
+ return dict(self._engines[self])
+
+ @property
+ def engine(self) -> sa.Engine:
+ assert self._engines[self].get(
+ DEFAULT_KEY
+ ), f"{self.__class__.__name__} configuration is not ready"
+ return self._engines[self][DEFAULT_KEY]
+
+ def _setup(
+ self,
+ databases: t.Union[str, t.Dict[str, t.Any]],
+ models: t.Optional[t.List[str]] = None,
+ echo: bool = False,
+ ) -> None:
+ _configure_model(self, models)
+ self._build_engines(databases, echo)
+
+ def _build_engines(
+ self, databases: t.Union[str, t.Dict[str, t.Any]], echo: bool
+ ) -> None:
+ engine_options: t.Dict[str, t.Dict[str, t.Any]] = {}
+
+ if isinstance(databases, str):
+ common_engine_options = self._common_engine_options.copy()
+ common_engine_options["url"] = databases
+ engine_options.setdefault(DEFAULT_KEY, {}).update(common_engine_options)
+
+ elif isinstance(databases, dict):
+ for key, value in databases.items():
+ engine_options[key] = self._common_engine_options.copy()
+
+ if isinstance(value, (str, sa.engine.URL)):
+ engine_options[key]["url"] = value
+ else:
+ engine_options[key].update(value)
+ else:
+ raise RuntimeError(
+ "Invalid databases data structure. Allowed datastructure, str or dict data type"
+ )
+
+ if DEFAULT_KEY not in engine_options:
+ raise RuntimeError(
+ f"`default` database must be present in databases parameter: {databases}"
+ )
+
+ engines = self._engines.setdefault(self, {})
+
+ for key, options in engine_options.items():
+ make_metadata(key)
+
+ options.setdefault("echo", echo)
+ options.setdefault("echo_pool", echo)
+
+ self._validate_engine_option_defaults(options)
+ engines[key] = self._make_engine(options)
+
+ found_async_engine = [
+ engine for engine in engines.values() if engine.dialect.is_async
+ ]
+ if found_async_engine and len(found_async_engine) != len(engines):
+ raise Exception(
+ "Databases Configuration must either be all async or all synchronous type"
+ )
+
+ self._async_session_type = bool(len(found_async_engine))
+
+ def __validate_databases_input(self, *databases: str) -> t.Union[str, t.List[str]]:
+ _databases: t.Union[str, t.List[str]] = list(databases)
+ if len(_databases) == 0:
+ _databases = "__all__"
+ return _databases
+
+ def create_all(self, *databases: str) -> None:
+ _databases = self.__validate_databases_input(*databases)
+
+ metadata_engines = self._get_metadata_and_engine(_databases)
+
+ if self._async_session_type and _databases == "__all__":
+ raise Exception(
+ "You are using asynchronous database configuration. Use `create_all_async` instead"
+ )
+
+ for metadata_engine in metadata_engines:
+ metadata_engine.create_all()
+
+ def drop_all(self, *databases: str) -> None:
+ _databases = self.__validate_databases_input(*databases)
+
+ metadata_engines = self._get_metadata_and_engine(_databases)
+
+ if self._async_session_type and _databases == "__all__":
+ raise Exception(
+ "You are using asynchronous database configuration. Use `drop_all_async` instead"
+ )
+
+ for metadata_engine in metadata_engines:
+ metadata_engine.drop_all()
+
+ def reflect(self, *databases: str) -> None:
+ _databases = self.__validate_databases_input(*databases)
+
+ metadata_engines = self._get_metadata_and_engine(_databases)
+
+ if self._async_session_type and _databases == "__all__":
+ raise Exception(
+ "You are using asynchronous database configuration. Use `reflect_async` instead"
+ )
+ for metadata_engine in metadata_engines:
+ metadata_engine.reflect()
+
+ async def create_all_async(self, *databases: str) -> None:
+ _databases = self.__validate_databases_input(*databases)
+
+ metadata_engines = self._get_metadata_and_engine(_databases)
+
+ for metadata_engine in metadata_engines:
+ if not metadata_engine.is_async():
+ metadata_engine.create_all()
+ continue
+ await metadata_engine.create_all_async()
+
+ async def drop_all_async(self, *databases: str) -> None:
+ _databases = self.__validate_databases_input(*databases)
+
+ metadata_engines = self._get_metadata_and_engine(_databases)
+
+ for metadata_engine in metadata_engines:
+ if not metadata_engine.is_async():
+ metadata_engine.drop_all()
+ continue
+ await metadata_engine.drop_all_async()
+
+ async def reflect_async(self, *databases: str) -> None:
+ _databases = self.__validate_databases_input(*databases)
+
+ metadata_engines = self._get_metadata_and_engine(_databases)
+
+ for metadata_engine in metadata_engines:
+ if not metadata_engine.is_async():
+ metadata_engine.reflect()
+ continue
+ await metadata_engine.reflect_async()
+
+ def get_scoped_session(
+ self,
+ **extra_options: t.Any,
+ ) -> t.Union[
+ sa_orm.scoped_session[sa_orm.Session],
+ async_scoped_session[t.Union[AsyncSession, t.Any]],
+ ]:
+ options = self._session_options.copy()
+ options.update(extra_options)
+
+ scope = options.pop("scopefunc", get_ident)
+
+ factory = self._make_session_factory(options)
+
+ if self._async_session_type:
+ return async_scoped_session(factory, scope) # type:ignore[arg-type]
+
+ return sa_orm.scoped_session(factory, scope) # type:ignore[arg-type]
+
+ def _make_session_factory(
+ self, options: t.Dict[str, t.Any]
+ ) -> t.Union[sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession]]:
+ if self._async_session_type:
+ options.setdefault("sync_session_class", ModelSession)
+ else:
+ options.setdefault("class_", ModelSession)
+
+ session_class = options.get("class_", options.get("sync_session_class"))
+
+ if session_class is ModelSession or issubclass(session_class, ModelSession):
+ options.update(engines=self._engines[self])
+
+ if self._async_session_type:
+ return async_sessionmaker(**options)
+
+ return sa_orm.sessionmaker(**options)
+
+ def _validate_engine_option_defaults(self, options: t.Dict[str, t.Any]) -> None:
+ url = sa.engine.make_url(options["url"])
+
+ if url.drivername in {"sqlite", "sqlite+pysqlite", "sqlite+aiosqlite"}:
+ if url.database is None or url.database in {"", ":memory:"}:
+ options["poolclass"] = sa.pool.StaticPool
+
+ if "connect_args" not in options:
+ options["connect_args"] = {}
+
+ options["connect_args"]["check_same_thread"] = False
+
+ elif self._execution_path:
+ is_uri = url.query.get("uri", False)
+
+ if is_uri:
+ db_str = url.database[5:]
+ else:
+ db_str = url.database
+
+ if not os.path.isabs(db_str):
+ root_path = os.path.join(self._execution_path, "sqlite")
+ os.makedirs(root_path, exist_ok=True)
+ db_str = os.path.join(root_path, db_str)
+
+ if is_uri:
+ db_str = f"file:{db_str}"
+
+ options["url"] = url.set(database=db_str)
+
+ elif url.drivername.startswith("mysql"):
+ # set queue defaults only when using queue pool
+ if (
+ "pool_class" not in options
+ or options["pool_class"] is sa.pool.QueuePool
+ ):
+ options.setdefault("pool_recycle", 7200)
+
+ if "charset" not in url.query:
+ options["url"] = url.update_query_dict({"charset": "utf8mb4"})
+
+ def _make_engine(self, options: t.Dict[str, t.Any]) -> sa.engine.Engine:
+ engine = sa.engine_from_config(options, prefix="")
+
+ # if engine.dialect.is_async:
+ # return AsyncEngine(engine)
+
+ return engine
+
+ def _get_metadata_and_engine(
+ self, database: t.Union[str, t.List[str]] = "__all__"
+ ) -> t.List[MetaDataEngine]:
+ engines = self._engines[self]
+
+ if database == "__all__":
+ keys: t.List[str] = list(get_database_binds())
+ elif isinstance(database, str):
+ keys = [database]
+ else:
+ keys = database
+
+ result: t.List[MetaDataEngine] = []
+
+ for key in keys:
+ try:
+ engine = engines[key]
+ except KeyError:
+ message = f"Bind key '{key}' is not in 'Database' config."
+ raise sa_exc.UnboundExecutionError(message) from None
+
+ metadata = get_database_bind(key, certain=True)
+ result.append(MetaDataEngine(metadata=metadata, engine=engine))
+ return result
diff --git a/ellar_sqlalchemy/services/metadata_engine.py b/ellar_sqlalchemy/services/metadata_engine.py
new file mode 100644
index 0000000..1d3dec3
--- /dev/null
+++ b/ellar_sqlalchemy/services/metadata_engine.py
@@ -0,0 +1,39 @@
+from __future__ import annotations
+
+import dataclasses
+
+import sqlalchemy as sa
+from sqlalchemy.ext.asyncio import AsyncEngine
+
+
+@dataclasses.dataclass
+class MetaDataEngine:
+ metadata: sa.MetaData
+ engine: sa.Engine
+
+ def is_async(self) -> bool:
+ return self.engine.dialect.is_async
+
+ def create_all(self) -> None:
+ self.metadata.create_all(bind=self.engine)
+
+ async def create_all_async(self) -> None:
+ engine = AsyncEngine(self.engine)
+ async with engine.begin() as conn:
+ await conn.run_sync(self.metadata.create_all)
+
+ def drop_all(self) -> None:
+ self.metadata.drop_all(bind=self.engine)
+
+ async def drop_all_async(self) -> None:
+ engine = AsyncEngine(self.engine)
+ async with engine.begin() as conn:
+ await conn.run_sync(self.metadata.drop_all)
+
+ def reflect(self) -> None:
+ self.metadata.reflect(bind=self.engine)
+
+ async def reflect_async(self) -> None:
+ engine = AsyncEngine(self.engine)
+ async with engine.begin() as conn:
+ await conn.run_sync(self.metadata.reflect)
diff --git a/ellar_sqlalchemy/session.py b/ellar_sqlalchemy/session.py
new file mode 100644
index 0000000..d163c84
--- /dev/null
+++ b/ellar_sqlalchemy/session.py
@@ -0,0 +1,78 @@
+import typing as t
+
+import sqlalchemy as sa
+import sqlalchemy.exc as sa_exc
+import sqlalchemy.orm as sa_orm
+
+from ellar_sqlalchemy.constant import DEFAULT_KEY
+
+EngineType = t.Optional[t.Union[sa.engine.Engine, sa.engine.Connection]]
+
+
+def _get_engine_from_clause(
+ clause: t.Optional[sa.ClauseElement],
+ engines: t.Mapping[str, sa.Engine],
+) -> t.Optional[sa.Engine]:
+ table = None
+
+ if clause is not None:
+ if isinstance(clause, sa.Table):
+ table = clause
+ elif isinstance(clause, sa.UpdateBase) and isinstance(clause.table, sa.Table):
+ table = clause.table
+
+ if table is not None and "database_bind_key" in table.metadata.info:
+ key = table.metadata.info["database_bind_key"]
+
+ if key not in engines:
+ raise sa_exc.UnboundExecutionError(
+ f"Database Bind key '{key}' is not in 'Database' config."
+ )
+
+ return engines[key]
+
+ return None
+
+
+class ModelSession(sa_orm.Session):
+ def __init__(self, engines: t.Mapping[str, sa.Engine], **kwargs: t.Any) -> None:
+ super().__init__(**kwargs)
+ self._engines = engines
+ self._model_changes: t.Dict[object, t.Tuple[t.Any, str]] = {}
+
+ def get_bind( # type:ignore[override]
+ self,
+ mapper: t.Optional[t.Any] = None,
+ clause: t.Optional[t.Any] = None,
+ bind: EngineType = None,
+ **kwargs: t.Any,
+ ) -> EngineType:
+ if bind is not None:
+ return bind
+
+ engines = self._engines
+
+ if mapper is not None:
+ try:
+ mapper = sa.inspect(mapper)
+ except sa_exc.NoInspectionAvailable as e:
+ if isinstance(mapper, type):
+ raise sa_orm.exc.UnmappedClassError(mapper) from e
+
+ raise
+
+ engine = _get_engine_from_clause(mapper.local_table, engines)
+
+ if engine is not None:
+ return engine
+
+ if clause is not None:
+ engine = _get_engine_from_clause(clause, engines)
+
+ if engine is not None:
+ return engine
+
+ if DEFAULT_KEY in engines:
+ return engines[DEFAULT_KEY]
+
+ return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs)
diff --git a/ellar_sqlalchemy/templates/basic/README b/ellar_sqlalchemy/templates/basic/README
new file mode 100644
index 0000000..eaae251
--- /dev/null
+++ b/ellar_sqlalchemy/templates/basic/README
@@ -0,0 +1 @@
+Multi-database configuration for Flask.
diff --git a/ellar_sqlalchemy/templates/basic/alembic.ini.mako b/ellar_sqlalchemy/templates/basic/alembic.ini.mako
new file mode 100644
index 0000000..fc59da6
--- /dev/null
+++ b/ellar_sqlalchemy/templates/basic/alembic.ini.mako
@@ -0,0 +1,50 @@
+# A generic, single database configuration.
+
+[alembic]
+# template used to generate migration files
+file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic,flask_migrate
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[logger_flask_migrate]
+level = INFO
+handlers =
+qualname = flask_migrate
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/ellar_sqlalchemy/templates/basic/env.py b/ellar_sqlalchemy/templates/basic/env.py
new file mode 100644
index 0000000..4edeef8
--- /dev/null
+++ b/ellar_sqlalchemy/templates/basic/env.py
@@ -0,0 +1,48 @@
+import asyncio
+import typing as t
+from logging.config import fileConfig
+
+from alembic import context
+from ellar.app import current_injector
+
+from ellar_sqlalchemy.migrations import (
+ MultipleDatabaseAlembicEnvMigration,
+ SingleDatabaseAlembicEnvMigration,
+)
+from ellar_sqlalchemy.services import EllarSQLAlchemyService
+
+db_service: EllarSQLAlchemyService = current_injector.get(EllarSQLAlchemyService)
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+fileConfig(config.config_file_name) # type:ignore[arg-type]
+# logger = logging.getLogger("alembic.env")
+
+
+AlembicEnvMigrationKlass: t.Type[
+ t.Union[MultipleDatabaseAlembicEnvMigration, SingleDatabaseAlembicEnvMigration]
+] = (
+ MultipleDatabaseAlembicEnvMigration
+ if len(db_service.engines) > 1
+ else SingleDatabaseAlembicEnvMigration
+)
+
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+
+alembic_env_migration = AlembicEnvMigrationKlass(db_service)
+
+if context.is_offline_mode():
+ alembic_env_migration.run_migrations_offline(context) # type:ignore[arg-type]
+else:
+ asyncio.get_event_loop().run_until_complete(
+ alembic_env_migration.run_migrations_online(context) # type:ignore[arg-type]
+ )
diff --git a/ellar_sqlalchemy/templates/basic/script.py.mako b/ellar_sqlalchemy/templates/basic/script.py.mako
new file mode 100644
index 0000000..4b7a50f
--- /dev/null
+++ b/ellar_sqlalchemy/templates/basic/script.py.mako
@@ -0,0 +1,63 @@
+<%!
+import re
+
+%>"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+
+<%!
+ from ellar.app import current_injector
+ from ellar_sqlalchemy.services import EllarSQLAlchemyService
+
+ db_service = current_injector.get(EllarSQLAlchemyService)
+ db_names = list(db_service.engines.keys())
+%>
+
+% if len(db_names) > 1:
+
+def upgrade(engine_name):
+ globals()["upgrade_%s" % engine_name]()
+
+
+def downgrade(engine_name):
+ globals()["downgrade_%s" % engine_name]()
+
+
+
+## generate an "upgrade_() / downgrade_()" function
+## for each database name in the ini file.
+
+% for db_name in db_names:
+
+def upgrade_${db_name}():
+ ${context.get("%s_upgrades" % db_name, "pass")}
+
+
+def downgrade_${db_name}():
+ ${context.get("%s_downgrades" % db_name, "pass")}
+
+% endfor
+
+% else:
+
+def upgrade():
+ ${upgrades if upgrades else "pass"}
+
+
+def downgrade():
+ ${downgrades if downgrades else "pass"}
+
+% endif
diff --git a/ellar_sqlalchemy/types.py b/ellar_sqlalchemy/types.py
new file mode 100644
index 0000000..79e7775
--- /dev/null
+++ b/ellar_sqlalchemy/types.py
@@ -0,0 +1,15 @@
+import typing as t
+
+if t.TYPE_CHECKING:
+ from alembic.operations import MigrationScript
+ from alembic.runtime.migration import MigrationContext
+
+RevisionArgs = t.Union[
+ str,
+ t.Iterable[t.Optional[str]],
+ t.Iterable[str],
+]
+
+RevisionDirectiveCallable = t.Callable[
+ ["MigrationContext", RevisionArgs, t.List["MigrationScript"]], None
+]
diff --git a/examples/single-db/README.md b/examples/single-db/README.md
new file mode 100644
index 0000000..4e9d856
--- /dev/null
+++ b/examples/single-db/README.md
@@ -0,0 +1,26 @@
+## Ellar SQLAlchemy Single Database Example
+Project Description
+
+## Requirements
+Python >= 3.7
+Starlette
+Injector
+
+## Project setup
+```shell
+pip install poetry
+```
+then,
+```shell
+poetry install
+```
+### Apply Migration
+```shell
+ellar db upgrade
+```
+
+### Development Server
+```shell
+ellar runserver --reload
+```
+then, visit [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs)
\ No newline at end of file
diff --git a/examples/single-db/db/__init__.py b/examples/single-db/db/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/single-db/db/controllers.py b/examples/single-db/db/controllers.py
new file mode 100644
index 0000000..0632fb8
--- /dev/null
+++ b/examples/single-db/db/controllers.py
@@ -0,0 +1,47 @@
+"""
+Define endpoints routes in python class-based fashion
+example:
+
+@Controller("/dogs", tag="Dogs", description="Dogs Resources")
+class MyController(ControllerBase):
+ @get('/')
+ def index(self):
+ return {'detail': "Welcome Dog's Resources"}
+"""
+from ellar.common import Controller, ControllerBase, get, post, Body
+from pydantic import EmailStr
+from sqlalchemy import select
+
+from .models.users import User
+
+
+@Controller
+class DbController(ControllerBase):
+
+ @get("/")
+ def index(self):
+ return {"detail": "Welcome Db Resource"}
+
+ @post("/users")
+ def create_user(self, username: Body[str], email: Body[EmailStr]):
+ session = User.get_db_session()
+ user = User(username=username, email=email)
+
+ session.add(user)
+ session.commit()
+
+ return user.dict()
+
+ @get("/users/{user_id:int}")
+ def get_user_by_id(self, user_id: int):
+ session = User.get_db_session()
+ stmt = select(User).filter(User.id == user_id)
+ user = session.execute(stmt).scalar()
+ return user.dict()
+
+ @get("/users")
+ async def get_all_users(self):
+ session = User.get_db_session()
+ stmt = select(User)
+ rows = session.execute(stmt.offset(0).limit(100)).scalars()
+ return [row.dict() for row in rows]
diff --git a/examples/single-db/db/migrations/README b/examples/single-db/db/migrations/README
new file mode 100644
index 0000000..eaae251
--- /dev/null
+++ b/examples/single-db/db/migrations/README
@@ -0,0 +1 @@
+Multi-database configuration for Flask.
diff --git a/examples/single-db/db/migrations/alembic.ini b/examples/single-db/db/migrations/alembic.ini
new file mode 100644
index 0000000..fc59da6
--- /dev/null
+++ b/examples/single-db/db/migrations/alembic.ini
@@ -0,0 +1,50 @@
+# A generic, single database configuration.
+
+[alembic]
+# template used to generate migration files
+file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
+
+# set to 'true' to run the environment during
+# the 'revision' command, regardless of autogenerate
+# revision_environment = false
+
+
+# Logging configuration
+[loggers]
+keys = root,sqlalchemy,alembic,flask_migrate
+
+[handlers]
+keys = console
+
+[formatters]
+keys = generic
+
+[logger_root]
+level = WARN
+handlers = console
+qualname =
+
+[logger_sqlalchemy]
+level = WARN
+handlers =
+qualname = sqlalchemy.engine
+
+[logger_alembic]
+level = INFO
+handlers =
+qualname = alembic
+
+[logger_flask_migrate]
+level = INFO
+handlers =
+qualname = flask_migrate
+
+[handler_console]
+class = StreamHandler
+args = (sys.stderr,)
+level = NOTSET
+formatter = generic
+
+[formatter_generic]
+format = %(levelname)-5.5s [%(name)s] %(message)s
+datefmt = %H:%M:%S
diff --git a/examples/single-db/db/migrations/env.py b/examples/single-db/db/migrations/env.py
new file mode 100644
index 0000000..4edeef8
--- /dev/null
+++ b/examples/single-db/db/migrations/env.py
@@ -0,0 +1,48 @@
+import asyncio
+import typing as t
+from logging.config import fileConfig
+
+from alembic import context
+from ellar.app import current_injector
+
+from ellar_sqlalchemy.migrations import (
+ MultipleDatabaseAlembicEnvMigration,
+ SingleDatabaseAlembicEnvMigration,
+)
+from ellar_sqlalchemy.services import EllarSQLAlchemyService
+
+db_service: EllarSQLAlchemyService = current_injector.get(EllarSQLAlchemyService)
+
+# this is the Alembic Config object, which provides
+# access to the values within the .ini file in use.
+config = context.config
+
+# Interpret the config file for Python logging.
+# This line sets up loggers basically.
+fileConfig(config.config_file_name) # type:ignore[arg-type]
+# logger = logging.getLogger("alembic.env")
+
+
+AlembicEnvMigrationKlass: t.Type[
+ t.Union[MultipleDatabaseAlembicEnvMigration, SingleDatabaseAlembicEnvMigration]
+] = (
+ MultipleDatabaseAlembicEnvMigration
+ if len(db_service.engines) > 1
+ else SingleDatabaseAlembicEnvMigration
+)
+
+
+# other values from the config, defined by the needs of env.py,
+# can be acquired:
+# my_important_option = config.get_main_option("my_important_option")
+# ... etc.
+
+
+alembic_env_migration = AlembicEnvMigrationKlass(db_service)
+
+if context.is_offline_mode():
+ alembic_env_migration.run_migrations_offline(context) # type:ignore[arg-type]
+else:
+ asyncio.get_event_loop().run_until_complete(
+ alembic_env_migration.run_migrations_online(context) # type:ignore[arg-type]
+ )
diff --git a/examples/single-db/db/migrations/script.py.mako b/examples/single-db/db/migrations/script.py.mako
new file mode 100644
index 0000000..4b7a50f
--- /dev/null
+++ b/examples/single-db/db/migrations/script.py.mako
@@ -0,0 +1,63 @@
+<%!
+import re
+
+%>"""${message}
+
+Revision ID: ${up_revision}
+Revises: ${down_revision | comma,n}
+Create Date: ${create_date}
+
+"""
+from alembic import op
+import sqlalchemy as sa
+${imports if imports else ""}
+
+# revision identifiers, used by Alembic.
+revision = ${repr(up_revision)}
+down_revision = ${repr(down_revision)}
+branch_labels = ${repr(branch_labels)}
+depends_on = ${repr(depends_on)}
+
+<%!
+ from ellar.app import current_injector
+ from ellar_sqlalchemy.services import EllarSQLAlchemyService
+
+ db_service = current_injector.get(EllarSQLAlchemyService)
+ db_names = list(db_service.engines.keys())
+%>
+
+% if len(db_names) > 1:
+
+def upgrade(engine_name):
+ globals()["upgrade_%s" % engine_name]()
+
+
+def downgrade(engine_name):
+ globals()["downgrade_%s" % engine_name]()
+
+
+
+## generate an "upgrade_() / downgrade_()" function
+## for each database name in the ini file.
+
+% for db_name in db_names:
+
+def upgrade_${db_name}():
+ ${context.get("%s_upgrades" % db_name, "pass")}
+
+
+def downgrade_${db_name}():
+ ${context.get("%s_downgrades" % db_name, "pass")}
+
+% endfor
+
+% else:
+
+def upgrade():
+ ${upgrades if upgrades else "pass"}
+
+
+def downgrade():
+ ${downgrades if downgrades else "pass"}
+
+% endif
diff --git a/examples/single-db/db/migrations/versions/2023_12_30_2053-b7712f83d45b_first_migration.py b/examples/single-db/db/migrations/versions/2023_12_30_2053-b7712f83d45b_first_migration.py
new file mode 100644
index 0000000..562feb5
--- /dev/null
+++ b/examples/single-db/db/migrations/versions/2023_12_30_2053-b7712f83d45b_first_migration.py
@@ -0,0 +1,39 @@
+"""first migration
+
+Revision ID: b7712f83d45b
+Revises:
+Create Date: 2023-12-30 20:53:37.393009
+
+"""
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = 'b7712f83d45b'
+down_revision = None
+branch_labels = None
+depends_on = None
+
+
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('user',
+ sa.Column('id', sa.Integer(), nullable=False),
+ sa.Column('username', sa.String(), nullable=False),
+ sa.Column('email', sa.String(), nullable=False),
+ sa.Column('created_date', sa.DateTime(), nullable=False),
+ sa.Column('time_updated', sa.DateTime(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name=op.f('pk_user')),
+ sa.UniqueConstraint('username', name=op.f('uq_user_username'))
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('user')
+ # ### end Alembic commands ###
+
diff --git a/examples/single-db/db/models/__init__.py b/examples/single-db/db/models/__init__.py
new file mode 100644
index 0000000..67c9337
--- /dev/null
+++ b/examples/single-db/db/models/__init__.py
@@ -0,0 +1,5 @@
+from .users import User
+
+__all__ = [
+ 'User',
+]
diff --git a/examples/single-db/db/models/base.py b/examples/single-db/db/models/base.py
new file mode 100644
index 0000000..0f35189
--- /dev/null
+++ b/examples/single-db/db/models/base.py
@@ -0,0 +1,27 @@
+from datetime import datetime
+from sqlalchemy import DateTime, func, MetaData
+from sqlalchemy.orm import Mapped, mapped_column
+
+from ellar_sqlalchemy.model import Model
+
+convention = {
+ "ix": "ix_%(column_0_label)s",
+ "uq": "uq_%(table_name)s_%(column_0_name)s",
+ "ck": "ck_%(table_name)s_%(constraint_name)s",
+ "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
+ "pk": "pk_%(table_name)s",
+}
+
+
+class Base(Model, as_base=True):
+ __database__ = 'default'
+
+ metadata = MetaData(naming_convention=convention)
+
+ created_date: Mapped[datetime] = mapped_column(
+ "created_date", DateTime, default=datetime.utcnow, nullable=False
+ )
+
+ time_updated: Mapped[datetime] = mapped_column(
+ "time_updated", DateTime, nullable=False, default=datetime.utcnow, onupdate=func.now()
+ )
diff --git a/examples/single-db/db/models/users.py b/examples/single-db/db/models/users.py
new file mode 100644
index 0000000..489a36d
--- /dev/null
+++ b/examples/single-db/db/models/users.py
@@ -0,0 +1,18 @@
+
+from sqlalchemy import Integer, String
+from sqlalchemy.orm import Mapped, mapped_column
+from .base import Base
+
+class User(Base):
+ id: Mapped[int] = mapped_column(Integer, primary_key=True)
+ username: Mapped[str] = mapped_column(String, unique=True, nullable=False)
+ email: Mapped[str] = mapped_column(String)
+
+
+
+assert getattr(User, '__dnd__', None) == 'Ellar'
+
+# assert session
+
+
+
diff --git a/examples/single-db/db/module.py b/examples/single-db/db/module.py
new file mode 100644
index 0000000..7ff4082
--- /dev/null
+++ b/examples/single-db/db/module.py
@@ -0,0 +1,56 @@
+"""
+@Module(
+ controllers=[MyController],
+ providers=[
+ YourService,
+ ProviderConfig(IService, use_class=AService),
+ ProviderConfig(IFoo, use_value=FooService()),
+ ],
+ routers=(routerA, routerB)
+ statics='statics',
+ template='template_folder',
+ # base_directory -> default is the `db` folder
+)
+class MyModule(ModuleBase):
+ def register_providers(self, container: Container) -> None:
+ # for more complicated provider registrations
+ pass
+
+"""
+from ellar.app import App
+from ellar.common import Module, IApplicationStartup
+from ellar.core import ModuleBase
+from ellar.di import Container
+from ellar_sqlalchemy import EllarSQLAlchemyModule, EllarSQLAlchemyService
+
+from .controllers import DbController
+
+
+@Module(
+ controllers=[DbController],
+ providers=[],
+ routers=[],
+ modules=[
+ EllarSQLAlchemyModule.setup(
+ databases={
+ 'default': 'sqlite:///project.db',
+ },
+ echo=True,
+ migration_options={
+ 'directory': '__main__/migrations'
+ },
+ models=['db.models']
+ )
+ ]
+)
+class DbModule(ModuleBase, IApplicationStartup):
+ """
+ Db Module
+ """
+
+ async def on_startup(self, app: App) -> None:
+ db_service = app.injector.get(EllarSQLAlchemyService)
+ # db_service.create_all()
+
+ def register_providers(self, container: Container) -> None:
+ """for more complicated provider registrations, use container.register_instance(...) """
\ No newline at end of file
diff --git a/examples/single-db/db/sqlite/project.db b/examples/single-db/db/sqlite/project.db
new file mode 100644
index 0000000..77a0830
Binary files /dev/null and b/examples/single-db/db/sqlite/project.db differ
diff --git a/examples/single-db/db/tests/__init__.py b/examples/single-db/db/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/single-db/pyproject.toml b/examples/single-db/pyproject.toml
new file mode 100644
index 0000000..c9bb429
--- /dev/null
+++ b/examples/single-db/pyproject.toml
@@ -0,0 +1,26 @@
+[tool.poetry]
+name = "single-db"
+version = "0.1.0"
+description = "Demonstrating SQLAlchemy with Ellar"
+authors = ["Ezeudoh Tochukwu "]
+license = "MIT"
+readme = "README.md"
+packages = [{include = "single_db"}]
+
+[tool.poetry.dependencies]
+python = "^3.8"
+ellar-cli = "^0.2.6"
+ellar = "^0.6.2"
+
+
+[build-system]
+requires = ["poetry-core"]
+build-backend = "poetry.core.masonry.api"
+
+[ellar]
+default = "single_db"
+[ellar.projects.single_db]
+project-name = "single_db"
+application = "single_db.server:application"
+config = "single_db.config:DevelopmentConfig"
+root-module = "single_db.root_module:ApplicationModule"
diff --git a/examples/single-db/single_db/__init__.py b/examples/single-db/single_db/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/examples/single-db/single_db/config.py b/examples/single-db/single_db/config.py
new file mode 100644
index 0000000..d393f88
--- /dev/null
+++ b/examples/single-db/single_db/config.py
@@ -0,0 +1,82 @@
+"""
+Application Configurations
+Default Ellar Configurations are exposed here through `ConfigDefaultTypesMixin`
+Make changes and define your own configurations specific to your application
+
+export ELLAR_CONFIG_MODULE=ellar_sqlachemy_example.config:DevelopmentConfig
+"""
+
+import typing as t
+
+from ellar.pydantic import ENCODERS_BY_TYPE as encoders_by_type
+from starlette.middleware import Middleware
+from ellar.common import IExceptionHandler, JSONResponse
+from ellar.core import ConfigDefaultTypesMixin
+from ellar.core.versioning import BaseAPIVersioning, DefaultAPIVersioning
+
+
+class BaseConfig(ConfigDefaultTypesMixin):
+ DEBUG: bool = False
+
+ DEFAULT_JSON_CLASS: t.Type[JSONResponse] = JSONResponse
+ SECRET_KEY: str = "ellar_wltsSVEySCVC3xC2i4a0y40jlbcjTupkCX0TSoUT-R4"
+
+ # injector auto_bind = True allows you to resolve types that are not registered on the container
+ # For more info, read: https://injector.readthedocs.io/en/latest/index.html
+ INJECTOR_AUTO_BIND = False
+
+ # jinja Environment options
+ # https://jinja.palletsprojects.com/en/3.0.x/api/#high-level-api
+ JINJA_TEMPLATES_OPTIONS: t.Dict[str, t.Any] = {}
+
+ # Application route versioning scheme
+ VERSIONING_SCHEME: BaseAPIVersioning = DefaultAPIVersioning()
+
+ # Enable or Disable Application Router route searching by appending backslash
+ REDIRECT_SLASHES: bool = False
+
+ # Define references to static folders in python packages.
+ # eg STATIC_FOLDER_PACKAGES = [('boostrap4', 'statics')]
+ STATIC_FOLDER_PACKAGES: t.Optional[t.List[t.Union[str, t.Tuple[str, str]]]] = []
+
+ # Define references to static folders defined within the project
+ STATIC_DIRECTORIES: t.Optional[t.List[t.Union[str, t.Any]]] = []
+
+ # static route path
+ STATIC_MOUNT_PATH: str = "/static"
+
+ CORS_ALLOW_ORIGINS: t.List[str] = ["*"]
+ CORS_ALLOW_METHODS: t.List[str] = ["*"]
+ CORS_ALLOW_HEADERS: t.List[str] = ["*"]
+ ALLOWED_HOSTS: t.List[str] = ["*"]
+
+ # Application middlewares
+ MIDDLEWARE: t.Sequence[Middleware] = []
+
+ # A dictionary mapping either integer status codes,
+ # or exception class types onto callables which handle the exceptions.
+ # Exception handler callables should be of the form
+ # `handler(context:IExecutionContext, exc: Exception) -> response`
+ # and may be either standard functions, or async functions.
+ EXCEPTION_HANDLERS: t.List[IExceptionHandler] = []
+
+ # Object Serializer custom encoders
+ SERIALIZER_CUSTOM_ENCODER: t.Dict[
+ t.Any, t.Callable[[t.Any], t.Any]
+ ] = encoders_by_type
+
+
+class DevelopmentConfig(BaseConfig):
+ DEBUG: bool = True
+ # Configuration through Confog
+ SQLALCHEMY_CONFIG: t.Dict[str, t.Any] = {
+ 'databases': {
+ 'default': 'sqlite+aiosqlite:///project.db',
+ # 'db2': 'sqlite+aiosqlite:///project2.db',
+ },
+ 'echo': True,
+ 'migration_options': {
+ 'directory': '__main__/migrations'
+ },
+ 'models': ['db.models']
+ }
\ No newline at end of file
diff --git a/examples/single-db/single_db/root_module.py b/examples/single-db/single_db/root_module.py
new file mode 100644
index 0000000..710e4a2
--- /dev/null
+++ b/examples/single-db/single_db/root_module.py
@@ -0,0 +1,10 @@
+from ellar.common import Module, exception_handler, IExecutionContext, JSONResponse, Response
+from ellar.core import ModuleBase, LazyModuleImport as lazyLoad
+from ellar.samples.modules import HomeModule
+
+
+@Module(modules=[HomeModule, lazyLoad('db.module:DbModule')])
+class ApplicationModule(ModuleBase):
+ @exception_handler(404)
+ def exception_404_handler(cls, ctx: IExecutionContext, exc: Exception) -> Response:
+ return JSONResponse(dict(detail="Resource not found."), status_code=404)
\ No newline at end of file
diff --git a/examples/single-db/single_db/server.py b/examples/single-db/single_db/server.py
new file mode 100644
index 0000000..4b06927
--- /dev/null
+++ b/examples/single-db/single_db/server.py
@@ -0,0 +1,28 @@
+import os
+
+from ellar.app import AppFactory
+from ellar.common.constants import ELLAR_CONFIG_MODULE
+from ellar.core import LazyModuleImport as lazyLoad
+from ellar.openapi import OpenAPIDocumentModule, OpenAPIDocumentBuilder, SwaggerUI
+
+application = AppFactory.create_from_app_module(
+ lazyLoad("single_db.root_module:ApplicationModule"),
+ config_module=os.environ.get(
+ ELLAR_CONFIG_MODULE, "single_db.config:DevelopmentConfig"
+ ),
+ global_guards=[]
+)
+
+document_builder = OpenAPIDocumentBuilder()
+document_builder.set_title('Ellar Sqlalchemy Single Database Example') \
+ .set_version('1.0.2') \
+ .set_contact(name='Author Name', url='https://www.author-name.com', email='authorname@gmail.com') \
+ .set_license('MIT Licence', url='https://www.google.com')
+
+document = document_builder.build_document(application)
+module = OpenAPIDocumentModule.setup(
+ document=document,
+ docs_ui=SwaggerUI(dark_theme=True),
+ guards=[]
+)
+application.install_module(module)
\ No newline at end of file
diff --git a/examples/single-db/tests/conftest.py b/examples/single-db/tests/conftest.py
new file mode 100644
index 0000000..b18f954
--- /dev/null
+++ b/examples/single-db/tests/conftest.py
@@ -0,0 +1 @@
+from ellar.testing import Test, TestClient
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..bddfc1a
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,94 @@
+[build-system]
+requires = ["flit_core >=2,<4"]
+build-backend = "flit_core.buildapi"
+
+[tool.flit.module]
+name = "ellar_sqlalchemy"
+
+
+[project]
+name = "ellar-sqlalchemy"
+authors = [
+ {name = "Ezeudoh Tochukwu", email = "tochukwu.ezeudoh@gmail.com"},
+]
+dynamic = ["version", "description"]
+requires-python = ">=3.8"
+readme = "README.md"
+home-page = "https://github.com/python-ellar/ellar-sqlalchemy"
+classifiers = [
+ "Development Status :: 5 - Production/Stable",
+ "Environment :: Web Environment",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: MIT License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python",
+ "Topic :: Internet :: WWW/HTTP :: Dynamic Content",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3 :: Only",
+]
+
+dependencies = [
+ "ellar >= 0.6.2",
+ "sqlalchemy >=2.0.16",
+ "alembic >= 1.10.0",
+]
+
+dev = [
+ "pre-commit"
+]
+
+[project.urls]
+Documentation = "https://github.com/python-ellar/ellar-sqlalchemy"
+Source = "https://github.com/python-ellar/ellar-sqlalchemy"
+Homepage = "https://python-ellar.github.io/ellar-sqlalchemy/"
+"Bug Tracker" = "https://github.com/python-ellar/ellar-sqlalchemy/issues"
+
+[project.optional-dependencies]
+test = [
+ "pytest >= 7.1.3,<8.0.0",
+ "pytest-cov >= 2.12.0,<5.0.0",
+ "ruff ==0.1.7",
+ "mypy == 1.7.1",
+ "autoflake",
+]
+
+[tool.ruff]
+select = [
+ "E", # pycodestyle errors
+ "W", # pycodestyle warnings
+ "F", # pyflakes
+ "I", # isort
+ "C", # flake8-comprehensions
+ "B", # flake8-bugbear
+]
+ignore = [
+ "E501", # line too long, handled by black
+ "B008", # do not perform function calls in argument defaults
+ "C901", # too complex
+]
+
+[tool.ruff.per-file-ignores]
+"__init__.py" = ["F401"]
+
+[tool.ruff.isort]
+known-third-party = ["ellar"]
+
+[tool.mypy]
+python_version = "3.8"
+show_error_codes = true
+pretty = true
+strict = true
+# db.Model attribute doesn't recognize subclassing
+disable_error_code = ["name-defined", 'union-attr']
+# db.Model is Any
+disallow_subclassing_any = false
+[[tool.mypy.overrides]]
+module = "ellar_sqlalchemy.cli.commands"
+ignore_errors = true
+[[tool.mypy.overrides]]
+module = "ellar_sqlalchemy.migrations.*"
+disable_error_code = ["arg-type", 'union-attr']
\ No newline at end of file
diff --git a/pytest.ini b/pytest.ini
new file mode 100644
index 0000000..b1f9d2f
--- /dev/null
+++ b/pytest.ini
@@ -0,0 +1,8 @@
+[pytest]
+addopts = --strict-config --strict-markers
+xfail_strict = true
+junit_family = "xunit2"
+norecursedirs = examples/*
+
+[pytest-watch]
+runner= pytest --failed-first --maxfail=1 --no-success-flaky-report
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..e69de29