diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..4cafb10 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,9 @@ +*.pyc +.venv* +.vscode +.mypy_cache +.coverage +htmlcov + +dist +test.py diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..9d5b626 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +[flake8] +max-line-length = 88 +ignore = E203, E241, E501, W503, F811 +exclude = + .git, + __pycache__ + .history + tests/demo_project diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..b9038ca --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: "pip" + directory: "/" + schedule: + interval: "monthly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: monthly diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml new file mode 100644 index 0000000..4d56161 --- /dev/null +++ b/.github/workflows/publish.yml @@ -0,0 +1,25 @@ +name: Publish + +on: + release: + types: [published] + workflow_dispatch: + +jobs: + publish: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Install Flit + run: pip install flit + - name: Install Dependencies + run: flit install --symlink + - name: Publish + env: + FLIT_USERNAME: ${{ secrets.FLIT_USERNAME }} + FLIT_PASSWORD: ${{ secrets.FLIT_PASSWORD }} + run: flit publish diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..860f59f --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,25 @@ +name: Test + +on: + push: + pull_request: + types: [assigned, opened, synchronize, reopened] + +jobs: + test_coverage: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Install Flit + run: pip install flit + - name: Install Dependencies + run: flit install --symlink + - name: Test + run: make test-cov + - name: Coverage + uses: codecov/codecov-action@v3.1.4 diff --git a/.github/workflows/test_full.yml b/.github/workflows/test_full.yml new file mode 100644 index 0000000..f459fdd --- /dev/null +++ b/.github/workflows/test_full.yml @@ -0,0 +1,43 @@ +name: Full Test + +on: + push: + pull_request: + types: [assigned, opened, synchronize, reopened] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install Flit + run: pip install flit + - name: Install Dependencies + run: flit install --symlink + - name: Test + run: pytest tests + + codestyle: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: 3.8 + - name: Install Flit + run: pip install flit + - name: Install Dependencies + run: flit install --symlink + - name: Linting check + run: ruff check ellar_sqlalchemy tests + - name: mypy + run: mypy ellar_sqlalchemy diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a47d824 --- /dev/null +++ b/.gitignore @@ -0,0 +1,129 @@ +*.pyc + +# Byte-compiled / optimized / DLL files +__pycache__ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +# *.mo Needs to come with the package +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +.vscode +.mypy_cache +.coverage +htmlcov + +dist +test.py + +docs/site + +.DS_Store +.idea +local_install.sh +dist +test.py + +docs/site +site/ diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..1815422 --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,3 @@ +[settings] +profile = black +combine_as_imports = true diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..0fbadcd --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,46 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: check-merge-conflict +- repo: https://github.com/asottile/yesqa + rev: v1.3.0 + hooks: + - id: yesqa +- repo: local + hooks: + - id: code_formatting + args: [] + name: Code Formatting + entry: "make fmt" + types: [python] + language_version: python3.8 + language: python + - id: code_linting + args: [ ] + name: Code Linting + entry: "make lint" + types: [ python ] + language_version: python3.8 + language: python +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.3.0 + hooks: + - id: end-of-file-fixer + exclude: >- + ^examples/[^/]*\.svg$ + - id: requirements-txt-fixer + - id: trailing-whitespace + types: [python] + - id: check-case-conflict + - id: check-json + - id: check-xml + - id: check-executables-have-shebangs + - id: check-toml + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: check-added-large-files + - id: check-symlinks + - id: debug-statements + exclude: ^tests/ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..0714b54 --- /dev/null +++ b/Makefile @@ -0,0 +1,38 @@ +.PHONY: help docs +.DEFAULT_GOAL := help + +help: + @fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +clean: ## Removing cached python compiled files + find . -name \*pyc | xargs rm -fv + find . -name \*pyo | xargs rm -fv + find . -name \*~ | xargs rm -fv + find . -name __pycache__ | xargs rm -rfv + find . -name .ruff_cache | xargs rm -rfv + +install: ## Install dependencies + flit install --deps develop --symlink + +install-full: ## Install dependencies + make install + pre-commit install -f + +lint:fmt ## Run code linters + ruff check ellar_sqlalchemy tests + mypy ellar_sqlalchemy + +fmt format:clean ## Run code formatters + ruff format ellar_sqlalchemy tests + ruff check --fix ellar_sqlalchemy tests + +test: ## Run tests + pytest tests + +test-cov: ## Run tests with coverage + pytest --cov=ellar_sqlalchemy --cov-report term-missing tests + +pre-commit-lint: ## Runs Requires commands during pre-commit + make clean + make fmt + make lint diff --git a/README.md b/README.md new file mode 100644 index 0000000..a14fc58 --- /dev/null +++ b/README.md @@ -0,0 +1,157 @@ +

+ Ellar Logo +

+ +![Test](https://github.com/eadwinCode/ellar-sqlachemy/actions/workflows/test_full.yml/badge.svg) +![Coverage](https://img.shields.io/codecov/c/github/python-ellar/ellar-sqlachemy) +[![PyPI version](https://badge.fury.io/py/ellar-sqlachemy.svg)](https://badge.fury.io/py/ellar-sqlachemy) +[![PyPI version](https://img.shields.io/pypi/v/ellar-sqlachemy.svg)](https://pypi.python.org/pypi/ellar-sqlachemy) +[![PyPI version](https://img.shields.io/pypi/pyversions/ellar-sqlachemy.svg)](https://pypi.python.org/pypi/ellar-sqlachemy) + +## Project Status +- 70% done +- SQLAlchemy Table support with ModelSession +- Migration custom revision directives +- Documentation +- File Field +- Image Field + +## Introduction +Ellar SQLAlchemy Module simplifies the integration of SQLAlchemy and Alembic migration tooling into your ellar application. + +## Installation +```shell +$(venv) pip install ellar-sqlalchemy +``` + +## Features +- Automatic table name +- Session management during request and after request +- Support both async/sync SQLAlchemy operations in Session, Engine, and Connection. +- Multiple Database Support +- Database migrations for both single and multiple databases either async/sync database engine + +## **Usage** +In your ellar application, create a module called `db` or any name of your choice, +```shell +ellar create-module db +``` +Then, in `models/base.py` define your model base as shown below: + +```python +# db/models/base.py +from datetime import datetime +from sqlalchemy import DateTime, func +from sqlalchemy.orm import Mapped, mapped_column +from ellar_sqlalchemy.model import Model + + +class Base(Model, as_base=True): + __database__ = 'default' + + created_date: Mapped[datetime] = mapped_column( + "created_date", DateTime, default=datetime.utcnow, nullable=False + ) + + time_updated: Mapped[datetime] = mapped_column( + "time_updated", DateTime, nullable=False, default=datetime.utcnow, onupdate=func.now() + ) +``` + +Use `Base` to create other models, like users in `User` in +```python +# db/models/users.py +from sqlalchemy import Integer, String +from sqlalchemy.orm import Mapped, mapped_column +from .base import Base + + +class User(Base): + id: Mapped[int] = mapped_column(Integer, primary_key=True) + username: Mapped[str] = mapped_column(String, unique=True, nullable=False) + email: Mapped[str] = mapped_column(String) +``` + +### Configure Module +```python +# db/module.py +from ellar.app import App +from ellar.common import Module, IApplicationStartup +from ellar.core import ModuleBase +from ellar.di import Container +from ellar_sqlalchemy import EllarSQLAlchemyModule, EllarSQLAlchemyService + +from .controllers import DbController + +@Module( + controllers=[DbController], + providers=[], + routers=[], + modules=[ + EllarSQLAlchemyModule.setup( + databases={ + 'default': 'sqlite:///project.db', + }, + echo=True, + migration_options={ + 'directory': '__main__/migrations' + }, + models=['db.models.users'] + ) + ] +) +class DbModule(ModuleBase, IApplicationStartup): + """ + Db Module + """ + + async def on_startup(self, app: App) -> None: + db_service = app.injector.get(EllarSQLAlchemyService) + db_service.create_all() + + def register_providers(self, container: Container) -> None: + """for more complicated provider registrations, use container.register_instance(...) """ +``` + +### Model Usage +Database session exist at model level and can be accessed through `model.get_db_session()` eg, `User.get_db_session()` +```python +# db/models/controllers.py +from ellar.common import Controller, ControllerBase, get, post, Body +from pydantic import EmailStr +from sqlalchemy import select + +from .models.users import User + + +@Controller +class DbController(ControllerBase): + @post("/users") + async def create_user(self, username: Body[str], email: Body[EmailStr]): + session = User.get_db_session() + user = User(username=username, email=email) + + session.add(user) + session.commit() + + return user.dict() + + + @get("/users/{user_id:int}") + def get_user_by_id(self, user_id: int): + session = User.get_db_session() + stmt = select(User).filter(User.id==user_id) + user = session.execute(stmt).scalar() + return user.dict() + + @get("/users") + async def get_all_users(self): + session = User.get_db_session() + stmt = select(User) + rows = session.execute(stmt.offset(0).limit(100)).scalars() + return [row.dict() for row in rows] +``` + +## License + +Ellar is [MIT licensed](LICENSE). diff --git a/ellar_sqlalchemy/__init__.py b/ellar_sqlalchemy/__init__.py new file mode 100644 index 0000000..e50183a --- /dev/null +++ b/ellar_sqlalchemy/__init__.py @@ -0,0 +1,8 @@ +"""Ellar SQLAlchemy Module - Adds support for SQLAlchemy and Alembic package to your Ellar web Framework""" + +__version__ = "0.0.1" + +from .module import EllarSQLAlchemyModule +from .services import EllarSQLAlchemyService + +__all__ = ["EllarSQLAlchemyModule", "EllarSQLAlchemyService"] diff --git a/ellar_sqlalchemy/cli/__init__.py b/ellar_sqlalchemy/cli/__init__.py new file mode 100644 index 0000000..6951cd7 --- /dev/null +++ b/ellar_sqlalchemy/cli/__init__.py @@ -0,0 +1,3 @@ +from .commands import db as DBCommands + +__all__ = ["DBCommands"] diff --git a/ellar_sqlalchemy/cli/commands.py b/ellar_sqlalchemy/cli/commands.py new file mode 100644 index 0000000..ba624fe --- /dev/null +++ b/ellar_sqlalchemy/cli/commands.py @@ -0,0 +1,404 @@ +import click +from ellar.app import current_injector + +from ellar_sqlalchemy.services import EllarSQLAlchemyService + +from .handlers import CLICommandHandlers + + +@click.group() +def db(): + """- Perform Alembic Database Commands -""" + pass + + +def _get_handler_context(ctx: click.Context) -> CLICommandHandlers: + db_service = current_injector.get(EllarSQLAlchemyService) + return CLICommandHandlers(db_service) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option("-m", "--message", default=None, help="Revision message") +@click.option( + "--autogenerate", + is_flag=True, + help=( + "Populate revision script with candidate migration " + "operations, based on comparison of database to model" + ), +) +@click.option( + "--sql", + is_flag=True, + help="Don't emit SQL to database - dump to standard output " "instead", +) +@click.option( + "--head", + default="head", + help="Specify head revision or @head to base new " "revision on", +) +@click.option( + "--splice", + is_flag=True, + help='Allow a non-head revision as the "head" to splice onto', +) +@click.option( + "--branch-label", + default=None, + help="Specify a branch label to apply to the new revision", +) +@click.option( + "--version-path", + default=None, + help="Specify specific path from config for version file", +) +@click.option( + "--rev-id", + default=None, + help="Specify a hardcoded revision id instead of generating " "one", +) +@click.pass_context +def revision( + ctx, + directory, + message, + autogenerate, + sql, + head, + splice, + branch_label, + version_path, + rev_id, +): + """- Create a new revision file.""" + handler = _get_handler_context(ctx) + handler.revision( + directory, + message, + autogenerate, + sql, + head, + splice, + branch_label, + version_path, + rev_id, + ) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option("-m", "--message", default=None, help="Revision message") +@click.option( + "--sql", + is_flag=True, + help="Don't emit SQL to database - dump to standard output " "instead", +) +@click.option( + "--head", + default="head", + help="Specify head revision or @head to base new " "revision on", +) +@click.option( + "--splice", + is_flag=True, + help='Allow a non-head revision as the "head" to splice onto', +) +@click.option( + "--branch-label", + default=None, + help="Specify a branch label to apply to the new revision", +) +@click.option( + "--version-path", + default=None, + help="Specify specific path from config for version file", +) +@click.option( + "--rev-id", + default=None, + help="Specify a hardcoded revision id instead of generating " "one", +) +@click.option( + "-x", + "--x-arg", + multiple=True, + help="Additional arguments consumed by custom env.py scripts", +) +@click.pass_context +def migrate( + ctx, + directory, + message, + sql, + head, + splice, + branch_label, + version_path, + rev_id, + x_arg, +): + """- Autogenerate a new revision file (Alias for + 'revision --autogenerate')""" + handler = _get_handler_context(ctx) + handler.migrate( + directory, + message, + sql, + head, + splice, + branch_label, + version_path, + rev_id, + x_arg, + ) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.argument("revision", default="head") +@click.pass_context +def edit(ctx, directory, revision): + """- Edit a revision file""" + handler = _get_handler_context(ctx) + handler.edit(directory, revision) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option("-m", "--message", default=None, help="Merge revision message") +@click.option( + "--branch-label", + default=None, + help="Specify a branch label to apply to the new revision", +) +@click.option( + "--rev-id", + default=None, + help="Specify a hardcoded revision id instead of generating " "one", +) +@click.argument("revisions", nargs=-1) +@click.pass_context +def merge(ctx, directory, message, branch_label, rev_id, revisions): + """- Merge two revisions together, creating a new revision file""" + handler = _get_handler_context(ctx) + handler.merge(directory, revisions, message, branch_label, rev_id) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option( + "--sql", + is_flag=True, + help="Don't emit SQL to database - dump to standard output " "instead", +) +@click.option( + "--tag", + default=None, + help='Arbitrary "tag" name - can be used by custom env.py ' "scripts", +) +@click.option( + "-x", + "--x-arg", + multiple=True, + help="Additional arguments consumed by custom env.py scripts", +) +@click.argument("revision", default="head") +@click.pass_context +def upgrade(ctx, directory, sql, tag, x_arg, revision): + """- Upgrade to a later version""" + handler = _get_handler_context(ctx) + handler.upgrade(directory, revision, sql, tag, x_arg) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option( + "--sql", + is_flag=True, + help="Don't emit SQL to database - dump to standard output " "instead", +) +@click.option( + "--tag", + default=None, + help='Arbitrary "tag" name - can be used by custom env.py ' "scripts", +) +@click.option( + "-x", + "--x-arg", + multiple=True, + help="Additional arguments consumed by custom env.py scripts", +) +@click.argument("revision", default="-1") +@click.pass_context +def downgrade(ctx: click.Context, directory, sql, tag, x_arg, revision): + """- Revert to a previous version""" + handler = _get_handler_context(ctx) + handler.downgrade(directory, revision, sql, tag, x_arg) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.argument("revision", default="head") +@click.pass_context +def show(ctx: click.Context, directory, revision): + """- Show the revision denoted by the given symbol.""" + handler = _get_handler_context(ctx) + handler.show(directory, revision) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option( + "-r", + "--rev-range", + default=None, + help="Specify a revision range; format is [start]:[end]", +) +@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output") +@click.option( + "-i", + "--indicate-current", + is_flag=True, + help="Indicate current version (Alembic 0.9.9 or greater is " "required)", +) +@click.pass_context +def history(ctx: click.Context, directory, rev_range, verbose, indicate_current): + """- List changeset scripts in chronological order.""" + handler = _get_handler_context(ctx) + handler.history(directory, rev_range, verbose, indicate_current) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output") +@click.option( + "--resolve-dependencies", + is_flag=True, + help="Treat dependency versions as down revisions", +) +@click.pass_context +def heads(ctx: click.Context, directory, verbose, resolve_dependencies): + """- Show current available heads in the script directory""" + handler = _get_handler_context(ctx) + handler.heads(directory, verbose, resolve_dependencies) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output") +@click.pass_context +def branches(ctx, directory, verbose): + """- Show current branch points""" + handler = _get_handler_context(ctx) + handler.branches(directory, verbose) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option("-v", "--verbose", is_flag=True, help="Use more verbose output") +@click.pass_context +def current(ctx: click.Context, directory, verbose): + """- Display the current revision for each database.""" + handler = _get_handler_context(ctx) + handler.current(directory, verbose) + + +@db.command() +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option( + "--sql", + is_flag=True, + help="Don't emit SQL to database - dump to standard output " "instead", +) +@click.option( + "--tag", + default=None, + help='Arbitrary "tag" name - can be used by custom env.py ' "scripts", +) +@click.argument("revision", default="head") +@click.pass_context +def stamp(ctx: click.Context, directory, sql, tag, revision): + """- 'stamp' the revision table with the given revision; don't run any + migrations""" + handler = _get_handler_context(ctx) + handler.stamp(directory, revision, sql, tag) + + +@db.command("init-migration") +@click.option( + "-d", + "--directory", + default=None, + help='Migration script directory (default is "migrations")', +) +@click.option( + "--package", + is_flag=True, + help="Write empty __init__.py files to the environment and " "version locations", +) +@click.pass_context +def init(ctx: click.Context, directory, package): + """Creates a new migration repository.""" + handler = _get_handler_context(ctx) + handler.alembic_init(directory, package) diff --git a/ellar_sqlalchemy/cli/handlers.py b/ellar_sqlalchemy/cli/handlers.py new file mode 100644 index 0000000..4a29f40 --- /dev/null +++ b/ellar_sqlalchemy/cli/handlers.py @@ -0,0 +1,264 @@ +from __future__ import annotations + +import argparse +import logging +import os +import sys +import typing as t +from functools import wraps +from pathlib import Path + +from alembic import command +from alembic.config import Config as AlembicConfig +from alembic.util.exc import CommandError +from ellar.app import App + +from ellar_sqlalchemy.services import EllarSQLAlchemyService + +log = logging.getLogger(__name__) +RevIdType = t.Union[str, t.List[str], t.Tuple[str, ...]] + + +class Config(AlembicConfig): + def get_template_directory(self) -> str: + package_dir = os.path.abspath(Path(__file__).parent.parent) + return os.path.join(package_dir, "templates") + + +def _catch_errors(f: t.Callable) -> t.Callable: # type:ignore[type-arg] + @wraps(f) + def wrapped(*args: t.Any, **kwargs: t.Any) -> None: + try: + f(*args, **kwargs) + except (CommandError, RuntimeError) as exc: + log.error("Error: " + str(exc)) + sys.exit(1) + + return wrapped + + +class CLICommandHandlers: + def __init__(self, db_service: EllarSQLAlchemyService) -> None: + self.db_service = db_service + + def get_config( + self, + directory: t.Optional[t.Any] = None, + x_arg: t.Optional[t.Any] = None, + opts: t.Optional[t.Any] = None, + ) -> Config: + directory = ( + str(directory) if directory else self.db_service.migration_options.directory + ) + + config = Config(os.path.join(directory, "alembic.ini")) + + config.set_main_option("script_location", directory) + config.set_main_option( + "sqlalchemy.url", str(self.db_service.engine.url).replace("%", "%%") + ) + + if config.cmd_opts is None: + config.cmd_opts = argparse.Namespace() + + for opt in opts or []: + setattr(config.cmd_opts, opt, True) + + if not hasattr(config.cmd_opts, "x"): + if x_arg is not None: + config.cmd_opts.x = [] + + if isinstance(x_arg, list) or isinstance(x_arg, tuple): + for x in x_arg: + config.cmd_opts.x.append(x) + else: + config.cmd_opts.x.append(x_arg) + else: + config.cmd_opts.x = None + return config + + @_catch_errors + def alembic_init(self, directory: str | None = None, package: bool = False) -> None: + """Creates a new migration repository""" + if directory is None: + directory = self.db_service.migration_options.directory + + config = Config() + config.set_main_option("script_location", directory) + config.config_file_name = os.path.join(directory, "alembic.ini") + + command.init(config, directory, template="basic", package=package) + + @_catch_errors + def revision( + self, + directory: str | None = None, + message: str | None = None, + autogenerate: bool = False, + sql: bool = False, + head: str = "head", + splice: bool = False, + branch_label: RevIdType | None = None, + version_path: str | None = None, + rev_id: str | None = None, + ) -> None: + """Create a new revision file.""" + opts = ["autogenerate"] if autogenerate else None + + config = self.get_config(directory, opts=opts) + command.revision( + config, + message, + autogenerate=autogenerate, + sql=sql, + head=head, + splice=splice, + branch_label=branch_label, + version_path=version_path, + rev_id=rev_id, + ) + + @_catch_errors + def migrate( + self, + directory: str | None = None, + message: str | None = None, + sql: bool = False, + head: str = "head", + splice: bool = False, + branch_label: RevIdType | None = None, + version_path: str | None = None, + rev_id: str | None = None, + x_arg: str | None = None, + ) -> None: + """Alias for 'revision --autogenerate'""" + config = self.get_config( + directory, + opts=["autogenerate"], + x_arg=x_arg, + ) + command.revision( + config, + message, + autogenerate=True, + sql=sql, + head=head, + splice=splice, + branch_label=branch_label, + version_path=version_path, + rev_id=rev_id, + ) + + @_catch_errors + def edit(self, directory: str | None = None, revision: str = "current") -> None: + """Edit current revision.""" + config = self.get_config(directory) + command.edit(config, revision) + + @_catch_errors + def merge( + self, + directory: str | None = None, + revisions: RevIdType = "", + message: str | None = None, + branch_label: RevIdType | None = None, + rev_id: str | None = None, + ) -> None: + """Merge two revisions together. Creates a new migration file""" + config = self.get_config(directory) + command.merge( + config, revisions, message=message, branch_label=branch_label, rev_id=rev_id + ) + + @_catch_errors + def upgrade( + self, + directory: str | None = None, + revision: str = "head", + sql: bool = False, + tag: str | None = None, + x_arg: str | None = None, + ) -> None: + """Upgrade to a later version""" + config = self.get_config(directory, x_arg=x_arg) + command.upgrade(config, revision, sql=sql, tag=tag) + + @_catch_errors + def downgrade( + self, + directory: str | None = None, + revision: str = "-1", + sql: bool = False, + tag: str | None = None, + x_arg: str | None = None, + ) -> None: + """Revert to a previous version""" + config = self.get_config(directory, x_arg=x_arg) + if sql and revision == "-1": + revision = "head:-1" + command.downgrade(config, revision, sql=sql, tag=tag) + + @_catch_errors + def show(self, directory: str | None = None, revision: str = "head") -> None: + """Show the revision denoted by the given symbol.""" + config = self.get_config(directory) + command.show(config, revision) # type:ignore[no-untyped-call] + + @_catch_errors + def history( + self, + directory: str | None = None, + rev_range: t.Any = None, + verbose: bool = False, + indicate_current: bool = False, + ) -> None: + """List changeset scripts in chronological order.""" + config = self.get_config(directory) + command.history( + config, rev_range, verbose=verbose, indicate_current=indicate_current + ) + + @_catch_errors + def heads( + self, + directory: str | None = None, + verbose: bool = False, + resolve_dependencies: bool = False, + ) -> None: + """Show current available heads in the script directory""" + config = self.get_config(directory) + command.heads( # type:ignore[no-untyped-call] + config, verbose=verbose, resolve_dependencies=resolve_dependencies + ) + + @_catch_errors + def branches(self, directory: str | None = None, verbose: bool = False) -> None: + """Show current branch points""" + config = self.get_config(directory) + command.branches(config, verbose=verbose) # type:ignore[no-untyped-call] + + @_catch_errors + def current(self, directory: str | None = None, verbose: bool = False) -> None: + """Display the current revision for each database.""" + config = self.get_config(directory) + command.current(config, verbose=verbose) + + @_catch_errors + def stamp( + self, + app: App, + directory: str | None = None, + revision: str = "head", + sql: bool = False, + tag: t.Any = None, + ) -> None: + """'stamp' the revision table with the given revision; don't run any + migrations""" + config = self.get_config(app, directory) + command.stamp(config, revision, sql=sql, tag=tag) + + @_catch_errors + def check(self, app: App, directory: str | None = None) -> None: + """Check if there are any new operations to migrate""" + config = self.get_config(app, directory) + command.check(config) diff --git a/ellar_sqlalchemy/constant.py b/ellar_sqlalchemy/constant.py new file mode 100644 index 0000000..1931674 --- /dev/null +++ b/ellar_sqlalchemy/constant.py @@ -0,0 +1,11 @@ +import sqlalchemy.orm as sa_orm + +DATABASE_BIND_KEY = "database_bind_key" +DEFAULT_KEY = "default" +DATABASE_KEY = "__database__" +TABLE_KEY = "__table__" +ABSTRACT_KEY = "__abstract__" + + +class DeclarativeBasePlaceHolder(sa_orm.DeclarativeBase): + pass diff --git a/ellar_sqlalchemy/exceptions.py b/ellar_sqlalchemy/exceptions.py new file mode 100644 index 0000000..e69de29 diff --git a/ellar_sqlalchemy/migrations/__init__.py b/ellar_sqlalchemy/migrations/__init__.py new file mode 100644 index 0000000..34dad68 --- /dev/null +++ b/ellar_sqlalchemy/migrations/__init__.py @@ -0,0 +1,9 @@ +from .base import AlembicEnvMigrationBase +from .multiple import MultipleDatabaseAlembicEnvMigration +from .single import SingleDatabaseAlembicEnvMigration + +__all__ = [ + "SingleDatabaseAlembicEnvMigration", + "MultipleDatabaseAlembicEnvMigration", + "AlembicEnvMigrationBase", +] diff --git a/ellar_sqlalchemy/migrations/base.py b/ellar_sqlalchemy/migrations/base.py new file mode 100644 index 0000000..ff1a959 --- /dev/null +++ b/ellar_sqlalchemy/migrations/base.py @@ -0,0 +1,51 @@ +import typing as t +from abc import abstractmethod + +from alembic.runtime.environment import NameFilterType +from sqlalchemy.sql.schema import SchemaItem + +from ellar_sqlalchemy.services import EllarSQLAlchemyService +from ellar_sqlalchemy.types import RevisionArgs + +if t.TYPE_CHECKING: + from alembic.operations import MigrationScript + from alembic.runtime.environment import EnvironmentContext + from alembic.runtime.migration import MigrationContext + + +class AlembicEnvMigrationBase: + def __init__(self, db_service: EllarSQLAlchemyService) -> None: + self.db_service = db_service + self.use_two_phase = db_service.migration_options.use_two_phase + + def include_object( + self, + obj: SchemaItem, + name: t.Optional[str], + type_: NameFilterType, + reflected: bool, + compare_to: t.Optional[SchemaItem], + ) -> bool: + # If you want to ignore things like these, set the following as a class attribute + # __table_args__ = {"info": {"skip_autogen": True}} + if obj.info.get("skip_autogen", False): + return False + + return True + + @abstractmethod + def default_process_revision_directives( + self, + context: "MigrationContext", + revision: RevisionArgs, + directives: t.List["MigrationScript"], + ) -> t.Any: + pass + + @abstractmethod + def run_migrations_offline(self, context: "EnvironmentContext") -> None: + pass + + @abstractmethod + async def run_migrations_online(self, context: "EnvironmentContext") -> None: + pass diff --git a/ellar_sqlalchemy/migrations/multiple.py b/ellar_sqlalchemy/migrations/multiple.py new file mode 100644 index 0000000..f899f26 --- /dev/null +++ b/ellar_sqlalchemy/migrations/multiple.py @@ -0,0 +1,212 @@ +import logging +import typing as t +from dataclasses import dataclass + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine + +from ellar_sqlalchemy.model.database_binds import get_database_bind +from ellar_sqlalchemy.types import RevisionArgs + +from .base import AlembicEnvMigrationBase + +if t.TYPE_CHECKING: + from alembic.operations import MigrationScript + from alembic.runtime.environment import EnvironmentContext + from alembic.runtime.migration import MigrationContext + + +logger = logging.getLogger("alembic.env") + + +@dataclass +class DatabaseInfo: + name: str + metadata: sa.MetaData + engine: t.Union[sa.Engine, AsyncEngine] + connection: t.Union[sa.Connection, AsyncConnection] + use_two_phase: bool = False + + _transaction: t.Optional[t.Union[sa.TwoPhaseTransaction, sa.RootTransaction]] = None + _sync_connection: t.Optional[sa.Connection] = None + + def sync_connection(self) -> sa.Connection: + if not self._sync_connection: + self._sync_connection = getattr( + self.connection, "sync_connection", self.connection + ) + assert self._sync_connection is not None + return self._sync_connection + + def get_transactions(self) -> t.Union[sa.TwoPhaseTransaction, sa.RootTransaction]: + if not self._transaction: + if self.use_two_phase: + self._transaction = self.sync_connection().begin_twophase() + else: + self._transaction = self.sync_connection().begin() + assert self._transaction is not None + return self._transaction + + +class MultipleDatabaseAlembicEnvMigration(AlembicEnvMigrationBase): + """ + Migration Class for Multiple Database Configuration + for both asynchronous and synchronous database engine dialect + """ + + def default_process_revision_directives( + self, + context: "MigrationContext", + revision: RevisionArgs, + directives: t.List["MigrationScript"], + ) -> None: + if getattr(context.config.cmd_opts, "autogenerate", False): + script = directives[0] + + if len(script.upgrade_ops_list) == len(self.db_service.engines.keys()): + # wait till there is a full check of all databases before removing empty operations + + for upgrade_ops in list(script.upgrade_ops_list): + if upgrade_ops.is_empty(): + script.upgrade_ops_list.remove(upgrade_ops) + + for downgrade_ops in list(script.downgrade_ops_list): + if downgrade_ops.is_empty(): + script.downgrade_ops_list.remove(downgrade_ops) + + if ( + len(script.upgrade_ops_list) == 0 + and len(script.downgrade_ops_list) == 0 + ): + directives[:] = [] + logger.info("No changes in schema detected.") + + def run_migrations_offline(self, context: "EnvironmentContext") -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation, + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + # for --sql use case, run migrations for each URL into + # individual files. + + for key, engine in self.db_service.engines.items(): + logger.info("Migrating database %s" % key) + + url = str(engine.url).replace("%", "%%") + metadata = get_database_bind(key, certain=True) + + file_ = "%s.sql" % key + logger.info("Writing output to %s" % file_) + with open(file_, "w") as buffer: + context.configure( + url=url, + output_buffer=buffer, + target_metadata=metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + # If you want to ignore things like these, set the following as a class attribute + # __table_args__ = {"info": {"skip_autogen": True}} + include_object=self.include_object, + # detecting type changes + # compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations(engine_name=key) + + def _migration_action( + self, _: t.Any, db_infos: t.List[DatabaseInfo], context: "EnvironmentContext" + ) -> None: + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + conf_args = { + "process_revision_directives": self.default_process_revision_directives + } + # conf_args = current_app.extensions['migrate'].configure_args + # if conf_args.get("process_revision_directives") is None: + # conf_args["process_revision_directives"] = process_revision_directives + + try: + for db_info in db_infos: + context.configure( + connection=db_info.sync_connection(), + upgrade_token="%s_upgrades" % db_info.name, + downgrade_token="%s_downgrades" % db_info.name, + target_metadata=db_info.metadata, + **conf_args, + ) + + context.run_migrations(engine_name=db_info.name) + + if self.use_two_phase: + for db_info in db_infos: + db_info.get_transactions().prepare() # type:ignore[attr-defined] + + for db_info in db_infos: + db_info.get_transactions().commit() + + except Exception as ex: + for db_info in db_infos: + db_info.get_transactions().rollback() + + logger.error(ex) + raise ex + finally: + for db_info in db_infos: + db_info.sync_connection().close() + + async def _check_if_coroutine(self, func: t.Any) -> t.Any: + if isinstance(func, t.Coroutine): + return await func + return func + + async def _compute_engine_info(self) -> t.List[DatabaseInfo]: + res = [] + + for key, engine in self.db_service.engines.items(): + metadata = get_database_bind(key, certain=True) + + if engine.dialect.is_async: + async_engine = AsyncEngine(engine) + connection = async_engine.connect() + connection = await connection.start() + engine = async_engine # type:ignore[assignment] + else: + connection = engine.connect() # type:ignore[assignment] + + database_info = DatabaseInfo( + engine=engine, + metadata=metadata, + connection=connection, + name=key, + use_two_phase=self.use_two_phase, + ) + database_info.get_transactions() + res.append(database_info) + return res + + async def run_migrations_online(self, context: "EnvironmentContext") -> None: + # for the direct-to-DB use case, start a transaction on all + # engines, then run all migrations, then commit all transactions. + + database_infos = await self._compute_engine_info() + async_db_info_filter = [ + db_info for db_info in database_infos if db_info.engine.dialect.is_async + ] + try: + if len(async_db_info_filter) > 0: + await async_db_info_filter[0].connection.run_sync( + self._migration_action, database_infos, context + ) + else: + self._migration_action(None, database_infos, context) + finally: + for database_info_ in database_infos: + await self._check_if_coroutine(database_info_.connection.close()) diff --git a/ellar_sqlalchemy/migrations/single.py b/ellar_sqlalchemy/migrations/single.py new file mode 100644 index 0000000..2923c66 --- /dev/null +++ b/ellar_sqlalchemy/migrations/single.py @@ -0,0 +1,113 @@ +import functools +import logging +import typing as t + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncEngine + +from ellar_sqlalchemy.model.database_binds import get_database_bind +from ellar_sqlalchemy.types import RevisionArgs + +from .base import AlembicEnvMigrationBase + +if t.TYPE_CHECKING: + from alembic.operations import MigrationScript + from alembic.runtime.environment import EnvironmentContext + from alembic.runtime.migration import MigrationContext + + +logger = logging.getLogger("alembic.env") + + +class SingleDatabaseAlembicEnvMigration(AlembicEnvMigrationBase): + """ + Migration Class for a Single Database Configuration + for both asynchronous and synchronous database engine dialect + """ + + def default_process_revision_directives( + self, + context: "MigrationContext", + revision: RevisionArgs, + directives: t.List["MigrationScript"], + ) -> None: + if getattr(context.config.cmd_opts, "autogenerate", False): + script = directives[0] + if script.upgrade_ops.is_empty(): + directives[:] = [] + logger.info("No changes in schema detected.") + + def run_migrations_offline(self, context: "EnvironmentContext") -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + + key, engine = self.db_service.engines.popitem() + metadata = get_database_bind(key, certain=True) + + context.configure( + url=str(engine.url).replace("%", "%%"), + target_metadata=metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + # If you want to ignore things like these, set the following as a class attribute + # __table_args__ = {"info": {"skip_autogen": True}} + include_object=self.include_object, + # detecting type changes + # compare_type=True, + ) + + with context.begin_transaction(): + context.run_migrations() + + def _migration_action( + self, + connection: sa.Connection, + metadata: sa.MetaData, + context: "EnvironmentContext", + ) -> None: + # this callback is used to prevent an auto-migration from being generated + # when there are no changes to the schema + # reference: http://alembic.zzzcomputing.com/en/latest/cookbook.html + conf_args = { + "process_revision_directives": self.default_process_revision_directives + } + # conf_args = current_app.extensions['migrate'].configure_args + # if conf_args.get("process_revision_directives") is None: + # conf_args["process_revision_directives"] = process_revision_directives + + context.configure(connection=connection, target_metadata=metadata, **conf_args) + + with context.begin_transaction(): + context.run_migrations() + + async def run_migrations_online(self, context: "EnvironmentContext") -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + key, engine = self.db_service.engines.popitem() + metadata = get_database_bind(key, certain=True) + + migration_action_partial = functools.partial( + self._migration_action, metadata=metadata, context=context + ) + + if engine.dialect.is_async: + async_engine = AsyncEngine(engine) + async with async_engine.connect() as connection: + await connection.run_sync(migration_action_partial) + else: + with engine.connect() as connection: + migration_action_partial(connection) diff --git a/ellar_sqlalchemy/model/__init__.py b/ellar_sqlalchemy/model/__init__.py new file mode 100644 index 0000000..5225e55 --- /dev/null +++ b/ellar_sqlalchemy/model/__init__.py @@ -0,0 +1,10 @@ +from .base import Model +from .typeDecorator import GUID, GenericIP +from .utils import make_metadata + +__all__ = [ + "Model", + "make_metadata", + "GUID", + "GenericIP", +] diff --git a/ellar_sqlalchemy/model/base.py b/ellar_sqlalchemy/model/base.py new file mode 100644 index 0000000..b5bd940 --- /dev/null +++ b/ellar_sqlalchemy/model/base.py @@ -0,0 +1,121 @@ +import types +import typing as t + +import sqlalchemy.orm as sa_orm +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import DeclarativeBase + +from ellar_sqlalchemy.constant import ( + DATABASE_BIND_KEY, + DEFAULT_KEY, +) + +from .database_binds import get_database_bind, has_database_bind, update_database_binds +from .mixins import ( + DatabaseBindKeyMixin, + ModelDataExportMixin, + ModelTrackMixin, + NameMetaMixin, +) + +SQLAlchemyDefaultBase = None + + +def _model_as_base( + name: str, bases: t.Tuple[t.Any, ...], namespace: t.Dict[str, t.Any] +) -> t.Type["Model"]: + global SQLAlchemyDefaultBase + + if SQLAlchemyDefaultBase is None: + declarative_bases = [ + b + for b in bases + if issubclass(b, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) + ] + + def get_session(cls: t.Type[Model]) -> None: + raise Exception("EllarSQLAlchemyService is not ready") + + namespace.update( + get_db_session=getattr(Model, "get_session", classmethod(get_session)), + skip_default_base_check=True, + ) + + model = types.new_class( + f"{name}", + ( + DatabaseBindKeyMixin, + NameMetaMixin, + ModelTrackMixin, + ModelDataExportMixin, + Model, + *declarative_bases, + sa_orm.DeclarativeBase, + ), + {}, + lambda ns: ns.update(namespace), + ) + model = t.cast(t.Type[Model], model) + SQLAlchemyDefaultBase = model + + if not has_database_bind(DEFAULT_KEY): + # Use the model's metadata as the default metadata. + model.metadata.info[DATABASE_BIND_KEY] = DEFAULT_KEY + update_database_binds(DEFAULT_KEY, model.metadata) + else: + # Use the passed in default metadata as the model's metadata. + model.metadata = get_database_bind(DEFAULT_KEY, certain=True) + return model + else: + return SQLAlchemyDefaultBase + + +class ModelMeta(type(DeclarativeBase)): # type:ignore[misc] + def __new__( + mcs, + name: str, + bases: t.Tuple[t.Any, ...], + namespace: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> t.Type[t.Union["Model", t.Any]]: + if bases == () and name == "Model": + return type.__new__(mcs, name, tuple(bases), namespace, **kwargs) + + if "as_base" in kwargs: + return _model_as_base(name, bases, namespace) + + _bases = list(bases) + + skip_default_base_check = False + if "skip_default_base_check" in namespace: + skip_default_base_check = namespace.pop("skip_default_base_check") + + if not skip_default_base_check: + if SQLAlchemyDefaultBase is None: + raise Exception( + "EllarSQLAlchemy Default Declarative Base has not been configured." + "\nPlease call `configure_model_declarative_base` before ORM Model construction" + " or Use EllarSQLAlchemy Service" + ) + elif SQLAlchemyDefaultBase and SQLAlchemyDefaultBase not in _bases: + _bases = [SQLAlchemyDefaultBase, *_bases] + + return super().__new__(mcs, name, (*_bases,), namespace, **kwargs) # type:ignore[no-any-return] + + +class Model(metaclass=ModelMeta): + __database__: str = "default" + + if t.TYPE_CHECKING: + + def __init__(self, **kwargs: t.Any) -> None: + ... + + @classmethod + def get_db_session( + cls, + ) -> t.Union[sa_orm.Session, AsyncSession, t.Any]: + ... + + def dict(self, exclude: t.Optional[t.Set[str]] = None) -> t.Dict[str, t.Any]: + ... diff --git a/ellar_sqlalchemy/model/database_binds.py b/ellar_sqlalchemy/model/database_binds.py new file mode 100644 index 0000000..a55c32a --- /dev/null +++ b/ellar_sqlalchemy/model/database_binds.py @@ -0,0 +1,23 @@ +import typing as t + +import sqlalchemy as sa + +__model_database_metadata__: t.Dict[str, sa.MetaData] = {} + + +def update_database_binds(key: str, value: sa.MetaData) -> None: + __model_database_metadata__[key] = value + + +def get_database_binds() -> t.Dict[str, sa.MetaData]: + return __model_database_metadata__.copy() + + +def get_database_bind(key: str, certain: bool = False) -> sa.MetaData: + if certain: + return __model_database_metadata__[key] + return __model_database_metadata__.get(key) # type:ignore[return-value] + + +def has_database_bind(key: str) -> bool: + return key in __model_database_metadata__ diff --git a/ellar_sqlalchemy/model/mixins.py b/ellar_sqlalchemy/model/mixins.py new file mode 100644 index 0000000..ded7d01 --- /dev/null +++ b/ellar_sqlalchemy/model/mixins.py @@ -0,0 +1,79 @@ +import typing as t + +import sqlalchemy as sa + +from ellar_sqlalchemy.constant import ABSTRACT_KEY, DATABASE_KEY, DEFAULT_KEY, TABLE_KEY + +from .utils import camel_to_snake_case, make_metadata, should_set_table_name + +if t.TYPE_CHECKING: + from .base import Model + +__ellar_sqlalchemy_models__: t.Dict[str, t.Type["Model"]] = {} + + +def get_registered_models() -> t.Dict[str, t.Type["Model"]]: + return __ellar_sqlalchemy_models__.copy() + + +class NameMetaMixin: + metadata: sa.MetaData + __tablename__: str + __table__: sa.Table + + def __init_subclass__(cls, **kwargs: t.Dict[str, t.Any]) -> None: + if should_set_table_name(cls): + cls.__tablename__ = camel_to_snake_case(cls.__name__) + + super().__init_subclass__(**kwargs) + + +class DatabaseBindKeyMixin: + metadata: sa.MetaData + __dnd__ = "Ellar" + + def __init_subclass__(cls, **kwargs: t.Dict[str, t.Any]) -> None: + if not ("metadata" in cls.__dict__ or TABLE_KEY in cls.__dict__) and hasattr( + cls, DATABASE_KEY + ): + database_bind_key = getattr(cls, DATABASE_KEY, DEFAULT_KEY) + parent_metadata = getattr(cls, "metadata", None) + metadata = make_metadata(database_bind_key) + + if metadata is not parent_metadata: + cls.metadata = metadata + + super().__init_subclass__(**kwargs) + + +class ModelTrackMixin: + metadata: sa.MetaData + + def __init_subclass__(cls, **kwargs: t.Dict[str, t.Any]) -> None: + super().__init_subclass__(**kwargs) + + if TABLE_KEY in cls.__dict__ and ABSTRACT_KEY not in cls.__dict__: + __ellar_sqlalchemy_models__[str(cls)] = cls # type:ignore[assignment] + + +class ModelDataExportMixin: + def __repr__(self) -> str: + columns = ", ".join( + [ + f"{k}={repr(v)}" + for k, v in self.__dict__.items() + if not k.startswith("_") + ] + ) + return f"<{self.__class__.__name__}({columns})>" + + def dict(self, exclude: t.Optional[t.Set[str]] = None) -> t.Dict[str, t.Any]: + # TODO: implement advance exclude and include that goes deep into relationships too + _exclude: t.Set[str] = set() if not exclude else exclude + + tuple_generator = ( + (k, v) + for k, v in self.__dict__.items() + if k not in _exclude and not k.startswith("_sa") + ) + return dict(tuple_generator) diff --git a/ellar_sqlalchemy/model/typeDecorator/__init__.py b/ellar_sqlalchemy/model/typeDecorator/__init__.py new file mode 100644 index 0000000..daddaa9 --- /dev/null +++ b/ellar_sqlalchemy/model/typeDecorator/__init__.py @@ -0,0 +1,13 @@ +# from .file import FileField +from .guid import GUID +from .ipaddress import GenericIP + +# from .image import CroppingDetails, ImageFileField + +__all__ = [ + "GUID", + "GenericIP", + # "CroppingDetails", + # "FileField", + # "ImageFileField", +] diff --git a/ellar_sqlalchemy/model/typeDecorator/exceptions.py.ellar b/ellar_sqlalchemy/model/typeDecorator/exceptions.py.ellar new file mode 100644 index 0000000..779b32a --- /dev/null +++ b/ellar_sqlalchemy/model/typeDecorator/exceptions.py.ellar @@ -0,0 +1,25 @@ +class ContentTypeValidationError(Exception): + def __init__(self, content_type=None, valid_content_types=None): + + if content_type is None: + message = "Content type is not provided. " + else: + message = "Content type is not supported %s. " % content_type + + if valid_content_types: + message += "Valid options are: %s" % ", ".join(valid_content_types) + + super().__init__(message) + + +class InvalidFileError(Exception): + pass + + +class InvalidImageOperationError(Exception): + pass + + +class MaximumAllowedFileLengthError(Exception): + def __init__(self, max_length: int): + super().__init__("Cannot store files larger than: %d bytes" % max_length) diff --git a/ellar_sqlalchemy/model/typeDecorator/file.py.ellar b/ellar_sqlalchemy/model/typeDecorator/file.py.ellar new file mode 100644 index 0000000..4a8a83a --- /dev/null +++ b/ellar_sqlalchemy/model/typeDecorator/file.py.ellar @@ -0,0 +1,208 @@ +import json +import time +import typing as t +import uuid + +from sqlalchemy import JSON, String, TypeDecorator +from starlette.datastructures import UploadFile + +from fullview_trader.core.storage import BaseStorage +from fullview_trader.core.storage.utils import get_length, get_valid_filename + +from .exceptions import ( + ContentTypeValidationError, + InvalidFileError, + MaximumAllowedFileLengthError, +) +from .mimetypes import guess_extension, magic_mime_from_buffer + +T = t.TypeVar("T", bound="FileObject") + + +class FileObject: + def __init__( + self, + *, + storage: BaseStorage, + original_filename: str, + uploaded_on: int, + content_type: str, + saved_filename: str, + extension: str, + file_size: int, + ) -> None: + self._storage = storage + self.original_filename = original_filename + self.uploaded_on = uploaded_on + self.content_type = content_type + self.filename = saved_filename + self.extension = extension + self.file_size = file_size + + def locate(self) -> str: + return self._storage.locate(self.filename) + + def open(self) -> t.IO: + return self._storage.open(self.filename) + + def to_dict(self) -> dict: + return { + "original_filename": self.original_filename, + "uploaded_on": self.uploaded_on, + "content_type": self.content_type, + "extension": self.extension, + "file_size": self.file_size, + "saved_filename": self.filename, + "service_name": self._storage.service_name(), + } + + def __str__(self) -> str: + return f"filename={self.filename}, content_type={self.content_type}, file_size={self.file_size}" + + def __repr__(self) -> str: + return str(self) + + +class FileFieldBase(t.Generic[T]): + FileObject: t.Type[T] = FileObject + + def load_dialect_impl(self, dialect): + if dialect.name == "sqlite": + return dialect.type_descriptor(String()) + else: + return dialect.type_descriptor(JSON()) + + def __init__( + self, + *args: t.Any, + storage: BaseStorage = None, + allowed_content_types: t.List[str] = None, + max_size: t.Optional[int] = None, + **kwargs: t.Any, + ): + if allowed_content_types is None: + allowed_content_types = [] + super().__init__(*args, **kwargs) + + self._storage = storage + self._allowed_content_types = allowed_content_types + self._max_size = max_size + + def validate(self, file: T) -> None: + if self._allowed_content_types and file.content_type not in self._allowed_content_types: + raise ContentTypeValidationError(file.content_type, self._allowed_content_types) + if self._max_size and file.file_size > self._max_size: + raise MaximumAllowedFileLengthError(self._max_size) + + self._storage.validate_file_name(file.filename) + + def load_from_str(self, data: str) -> T: + data_dict = t.cast(t.Dict, json.loads(data)) + return self.load(data_dict) + + def load(self, data: dict) -> T: + if "service_name" in data: + data.pop("service_name") + return self.FileObject(storage=self._storage, **data) + + def _guess_content_type(self, file: t.IO) -> str: + content = file.read(1024) + + if isinstance(content, str): + content = str.encode(content) + + file.seek(0) + + return magic_mime_from_buffer(content) + + def get_extra_file_initialization_context(self, file: UploadFile) -> dict: + return {} + + def convert_to_file_object(self, file: UploadFile) -> T: + unique_name = str(uuid.uuid4()) + + original_filename = file.filename + + # use python magic to get the content type + content_type = self._guess_content_type(file.file) + extension = guess_extension(content_type) + + file_size = get_length(file.file) + saved_filename = f"{original_filename[:-len(extension)]}_{unique_name[:-8]}{extension}" + saved_filename = get_valid_filename(saved_filename) + + init_kwargs = self.get_extra_file_initialization_context(file) + init_kwargs.update( + storage=self._storage, + original_filename=original_filename, + uploaded_on=int(time.time()), + content_type=content_type, + extension=extension, + file_size=file_size, + saved_filename=saved_filename, + ) + return self.FileObject(**init_kwargs) + + def process_bind_param_action( + self, value: t.Any, dialect: t.Any + ) -> t.Optional[t.Union[str, dict]]: + if value is None: + return value + + if isinstance(value, UploadFile): + value.file.seek(0) # make sure we are always at the beginning + file_obj = self.convert_to_file_object(value) + self.validate(file_obj) + + self._storage.put(file_obj.filename, value.file) + value = file_obj + + if isinstance(value, FileObject): + if dialect.name == "sqlite": + return json.dumps(value.to_dict()) + return value.to_dict() + + raise InvalidFileError() + + def process_result_value_action( + self, value: t.Any, dialect: t.Any + ) -> t.Optional[t.Union[str, dict]]: + if value is None: + return value + else: + if isinstance(value, str): + value = self.load_from_str(value) + elif isinstance(value, dict): + value = self.load(value) + return value + + +class FileField(FileFieldBase[FileObject], TypeDecorator): + """ + Provide SqlAlchemy TypeDecorator for saving files + ## Basic Usage + + fs = FileSystemStorage('path/to/save/files') + + class MyTable(Base): + image: FileField.FileObject = sa.Column( + ImageFileField(storage=fs, max_size=10*MB, allowed_content_type=["application/pdf"]), + nullable=True + ) + + def route(file: File[UploadFile]): + session = SessionLocal() + my_table_model = MyTable(image=file) + session.add(my_table_model) + session.commit() + return my_table_model.image.to_dict() + + """ + + impl = JSON + + def process_bind_param(self, value, dialect): + return self.process_bind_param_action(value, dialect) + + def process_result_value(self, value, dialect): + return self.process_result_value_action(value, dialect) diff --git a/ellar_sqlalchemy/model/typeDecorator/guid.py b/ellar_sqlalchemy/model/typeDecorator/guid.py new file mode 100644 index 0000000..e745753 --- /dev/null +++ b/ellar_sqlalchemy/model/typeDecorator/guid.py @@ -0,0 +1,47 @@ +import typing as t +import uuid + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import UUID +from sqlalchemy.types import CHAR + + +class GUID(sa.TypeDecorator): # type: ignore[type-arg] + """Platform-independent GUID type. + + Uses PostgreSQL's UUID type, otherwise uses + CHAR(32), storing as stringified hex values. + + """ + + impl = CHAR + + def load_dialect_impl(self, dialect: sa.Dialect) -> t.Any: + if dialect.name == "postgresql": + return dialect.type_descriptor(UUID()) + else: + return dialect.type_descriptor(CHAR(32)) + + def process_bind_param( + self, value: t.Optional[t.Any], dialect: sa.Dialect + ) -> t.Any: + if value is None: + return value + elif dialect.name == "postgresql": + return str(value) + else: + if not isinstance(value, uuid.UUID): + return "%.32x" % uuid.UUID(value).int + else: + # hexstring + return "%.32x" % value.int + + def process_result_value( + self, value: t.Optional[t.Any], dialect: sa.Dialect + ) -> t.Any: + if value is None: + return value + else: + if not isinstance(value, uuid.UUID): + value = uuid.UUID(value) + return value diff --git a/ellar_sqlalchemy/model/typeDecorator/image.py.ellar b/ellar_sqlalchemy/model/typeDecorator/image.py.ellar new file mode 100644 index 0000000..b8b298c --- /dev/null +++ b/ellar_sqlalchemy/model/typeDecorator/image.py.ellar @@ -0,0 +1,151 @@ +import typing as t +from dataclasses import dataclass +from io import SEEK_END, BytesIO + +from sqlalchemy import JSON, TypeDecorator +from starlette.datastructures import UploadFile + +from .exceptions import InvalidImageOperationError + +try: + from PIL import Image +except ImportError as im_ex: # pragma: no cover + raise Exception("Pillow package is required. Use `pip install Pillow`.") from im_ex + +from fullview_trader.core.storage import BaseStorage + +from .file import FileFieldBase, FileObject + + +@dataclass +class CroppingDetails: + x: int + y: int + height: int + width: int + + +class ImageFileObject(FileObject): + def __init__(self, *, height: float, width: float, **kwargs: t.Any) -> None: + super().__init__(**kwargs) + self.height = height + self.width = width + + def to_dict(self) -> dict: + data = super().to_dict() + data.update(height=self.height, width=self.width) + return data + + +class ImageFileField(FileFieldBase[ImageFileObject], TypeDecorator): + """ + Provide SqlAlchemy TypeDecorator for Image files + ## Basic Usage + + class MyTable(Base): + image: + ImageFileField.FileObject = sa.Column(ImageFileField(storage=FileSystemStorage('path/to/save/files', + max_size=10*MB), nullable=True) + + def route(file: File[UploadFile]): + session = SessionLocal() + my_table_model = MyTable(image=file) + session.add(my_table_model) + session.commit() + return my_table_model.image.to_dict() + + ## Cropping + Image file also provides cropping capabilities which can be defined in the column or when saving the image data. + + fs = FileSystemStorage('path/to/save/files') + class MyTable(Base): + image = sa.Column(ImageFileField(storage=fs, crop=CroppingDetails(x=100, y=200, height=400, width=400)), nullable=True) + + OR + def route(file: File[UploadFile]): + session = SessionLocal() + my_table_model = MyTable( + image=(file, CroppingDetails(x=100, y=200, height=400, width=400)), + ) + + """ + + impl = JSON + FileObject = ImageFileObject + + def __init__( + self, + *args: t.Any, + storage: BaseStorage, + max_size: t.Optional[int] = None, + crop: t.Optional[CroppingDetails] = None, + **kwargs: t.Any + ): + kwargs.setdefault("allowed_content_types", ["image/jpeg", "image/png"]) + super().__init__(*args, storage=storage, max_size=max_size, **kwargs) + self.crop = crop + + def process_bind_param(self, value, dialect): + return self.process_bind_param_action(value, dialect) + + def process_result_value(self, value, dialect): + return self.process_result_value_action(value, dialect) + + def get_extra_file_initialization_context(self, file: UploadFile) -> dict: + with Image.open(file.file) as image: + width, height = image.size + return {"width": width, "height": height} + + def crop_image_with_box_sizing( + self, file: UploadFile, crop: t.Optional[CroppingDetails] = None + ) -> UploadFile: + crop_info = crop or self.crop + img = Image.open(file.file) + (height, width, x, y,) = ( + crop_info.height, + crop_info.width, + crop_info.x, + crop_info.y, + ) + left = x + top = y + right = x + width + bottom = y + height + + crop_box = (left, top, right, bottom) + + img_res = img.crop(box=crop_box) + temp_thumb = BytesIO() + img_res.save(temp_thumb, img.format) + # Go to the end of the stream. + temp_thumb.seek(0, SEEK_END) + + # Get the current position, which is now at the end. + # We can use this as the size. + size = temp_thumb.tell() + temp_thumb.seek(0) + + content = UploadFile( + file=temp_thumb, filename=file.filename, size=size, headers=file.headers + ) + return content + + def process_bind_param_action( + self, value: t.Any, dialect: t.Any + ) -> t.Optional[t.Union[str, dict]]: + if isinstance(value, tuple): + file, crop_data = value + if not isinstance(file, UploadFile) or not isinstance(crop_data, CroppingDetails): + raise InvalidImageOperationError( + "Invalid data was provided for ImageFileField. " + "Accept values: UploadFile or (UploadFile, CroppingDetails)" + ) + new_file = self.crop_image_with_box_sizing(file=file, crop=crop_data) + return super().process_bind_param_action(new_file, dialect) + + if isinstance(value, UploadFile): + if self.crop: + return super().process_bind_param_action( + self.crop_image_with_box_sizing(value), dialect + ) + return super().process_bind_param_action(value, dialect) diff --git a/ellar_sqlalchemy/model/typeDecorator/ipaddress.py b/ellar_sqlalchemy/model/typeDecorator/ipaddress.py new file mode 100644 index 0000000..8c29f91 --- /dev/null +++ b/ellar_sqlalchemy/model/typeDecorator/ipaddress.py @@ -0,0 +1,39 @@ +import ipaddress +import typing as t + +import sqlalchemy as sa +import sqlalchemy.dialects as sa_dialects + + +class GenericIP(sa.TypeDecorator): # type:ignore[type-arg] + """ + Platform-independent IP Address type. + + Uses PostgreSQL's INET type, otherwise uses + CHAR(45), storing as stringified values. + """ + + impl = sa.CHAR + cache_ok = True + + def load_dialect_impl(self, dialect: sa.Dialect) -> t.Any: + if dialect.name == "postgresql": + return dialect.type_descriptor(sa_dialects.postgresql.INET()) # type:ignore[attr-defined] + else: + return dialect.type_descriptor(sa.CHAR(45)) + + def process_bind_param( + self, value: t.Optional[t.Any], dialect: sa.Dialect + ) -> t.Any: + if value is not None: + return str(value) + + def process_result_value( + self, value: t.Optional[t.Any], dialect: sa.Dialect + ) -> t.Any: + if value is None: + return value + + if not isinstance(value, (ipaddress.IPv4Address, ipaddress.IPv6Address)): + value = ipaddress.ip_address(value) + return value diff --git a/ellar_sqlalchemy/model/typeDecorator/mimetypes.py.ellar b/ellar_sqlalchemy/model/typeDecorator/mimetypes.py.ellar new file mode 100644 index 0000000..f287a76 --- /dev/null +++ b/ellar_sqlalchemy/model/typeDecorator/mimetypes.py.ellar @@ -0,0 +1,21 @@ +import mimetypes as mdb +import typing + +import magic + + +def magic_mime_from_buffer(buffer: bytes) -> str: + return magic.from_buffer(buffer, mime=True) + + +def guess_extension(mimetype: str) -> typing.Optional[str]: + """ + Due to the python bugs 'image/jpeg' overridden: + - https://bugs.python.org/issue4963 + - https://bugs.python.org/issue1043134 + - https://bugs.python.org/issue6626#msg91205 + """ + + if mimetype == "image/jpeg": + return ".jpeg" + return mdb.guess_extension(mimetype) diff --git a/ellar_sqlalchemy/model/utils.py b/ellar_sqlalchemy/model/utils.py new file mode 100644 index 0000000..5e5a796 --- /dev/null +++ b/ellar_sqlalchemy/model/utils.py @@ -0,0 +1,65 @@ +import re + +import sqlalchemy as sa +import sqlalchemy.orm as sa_orm + +from ellar_sqlalchemy.constant import DATABASE_BIND_KEY, DEFAULT_KEY + +from .database_binds import get_database_bind, has_database_bind, update_database_binds + + +def make_metadata(database_key: str) -> sa.MetaData: + if has_database_bind(database_key): + return get_database_bind(database_key, certain=True) + + if database_key is not None: + # Copy the naming convention from the default metadata. + naming_convention = make_metadata(DEFAULT_KEY).naming_convention + else: + naming_convention = None + + # Set the bind key in info to be used by session.get_bind. + metadata = sa.MetaData( + naming_convention=naming_convention, info={DATABASE_BIND_KEY: database_key} + ) + update_database_binds(database_key, metadata) + return metadata + + +def camel_to_snake_case(name: str) -> str: + """Convert a ``CamelCase`` name to ``snake_case``.""" + name = re.sub(r"((?<=[a-z0-9])[A-Z]|(?!^)[A-Z](?=[a-z]))", r"_\1", name) + return name.lower().lstrip("_") + + +def should_set_table_name(cls: type) -> bool: + if ( + cls.__dict__.get("__abstract__", False) + or ( + not issubclass(cls, (sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta)) + and not any(isinstance(b, sa_orm.DeclarativeMeta) for b in cls.__mro__[1:]) + ) + or any( + (b is sa_orm.DeclarativeBase or b is sa_orm.DeclarativeBaseNoMeta) + for b in cls.__bases__ + ) + ): + return False + + for base in cls.__mro__: + if "__tablename__" not in base.__dict__: + continue + + if isinstance(base.__dict__["__tablename__"], sa_orm.declared_attr): + return False + + return not ( + base is cls + or base.__dict__.get("__abstract__", False) + or not ( + isinstance(base, sa_orm.decl_api.DeclarativeAttributeIntercept) + or issubclass(base, sa_orm.DeclarativeBaseNoMeta) + ) + ) + + return True diff --git a/ellar_sqlalchemy/module.py b/ellar_sqlalchemy/module.py new file mode 100644 index 0000000..20f69d5 --- /dev/null +++ b/ellar_sqlalchemy/module.py @@ -0,0 +1,146 @@ +import functools +import typing as t + +import sqlalchemy as sa +from ellar.app import current_injector +from ellar.common import IApplicationShutdown, IModuleSetup, Module +from ellar.common.utils.importer import get_main_directory_by_stack +from ellar.core import Config, DynamicModule, ModuleBase, ModuleSetup +from ellar.di import ProviderConfig, request_scope +from sqlalchemy.ext.asyncio import ( + AsyncEngine, + AsyncSession, +) +from sqlalchemy.orm import Session + +from ellar_sqlalchemy.services import EllarSQLAlchemyService + +from .cli import DBCommands +from .schemas import MigrationOption, SQLAlchemyConfig + + +@Module(commands=[DBCommands]) +class EllarSQLAlchemyModule(ModuleBase, IModuleSetup, IApplicationShutdown): + async def on_shutdown(self) -> None: + db_service = current_injector.get(EllarSQLAlchemyService) + db_service.session_factory.remove() + + @classmethod + def setup( + cls, + *, + databases: t.Union[str, t.Dict[str, t.Any]], + migration_options: t.Union[t.Dict[str, t.Any], MigrationOption], + session_options: t.Optional[t.Dict[str, t.Any]] = None, + engine_options: t.Optional[t.Dict[str, t.Any]] = None, + models: t.Optional[t.List[str]] = None, + echo: bool = False, + ) -> "DynamicModule": + """ + Configures EllarSQLAlchemyModule and setup required providers. + """ + root_path = get_main_directory_by_stack("__main__", stack_level=2) + if isinstance(migration_options, dict): + migration_options.update( + directory=get_main_directory_by_stack( + migration_options.get("directory", "__main__/migrations"), + stack_level=2, + from_dir=root_path, + ) + ) + if isinstance(migration_options, MigrationOption): + migration_options.directory = get_main_directory_by_stack( + migration_options.directory, stack_level=2, from_dir=root_path + ) + migration_options = migration_options.dict() + + schema = SQLAlchemyConfig.model_validate( + { + "databases": databases, + "engine_options": engine_options, + "echo": echo, + "models": models, + "migration_options": migration_options, + "root_path": root_path, + }, + from_attributes=True, + ) + return cls.__setup_module(schema) + + @classmethod + def __setup_module(cls, sql_alchemy_config: SQLAlchemyConfig) -> DynamicModule: + db_service = EllarSQLAlchemyService( + databases=sql_alchemy_config.databases, + common_engine_options=sql_alchemy_config.engine_options, + common_session_options=sql_alchemy_config.session_options, + echo=sql_alchemy_config.echo, + models=sql_alchemy_config.models, + root_path=sql_alchemy_config.root_path, + migration_options=sql_alchemy_config.migration_options, + ) + providers: t.List[t.Any] = [] + + if db_service._async_session_type: + providers.append(ProviderConfig(AsyncEngine, use_value=db_service.engine)) + providers.append( + ProviderConfig( + AsyncSession, + use_value=lambda: db_service.session_factory(), + scope=request_scope, + ) + ) + else: + providers.append(ProviderConfig(sa.Engine, use_value=db_service.engine)) + providers.append( + ProviderConfig( + Session, + use_value=lambda: db_service.session_factory(), + scope=request_scope, + ) + ) + + providers.append(ProviderConfig(EllarSQLAlchemyService, use_value=db_service)) + return DynamicModule( + cls, + providers=providers, + ) + + @classmethod + def register_setup(cls, **override_config: t.Any) -> ModuleSetup: + """ + Register Module to be configured through `SQLALCHEMY_CONFIG` variable in Application Config + """ + root_path = get_main_directory_by_stack("__main__", stack_level=2) + return ModuleSetup( + cls, + inject=[Config], + factory=functools.partial( + cls.__register_setup_factory, + root_path=root_path, + override_config=override_config, + ), + ) + + @staticmethod + def __register_setup_factory( + module: t.Type["EllarSQLAlchemyModule"], + config: Config, + root_path: str, + override_config: t.Dict[str, t.Any], + ) -> DynamicModule: + if config.get("SQLALCHEMY_CONFIG") and isinstance( + config.SQLALCHEMY_CONFIG, dict + ): + defined_config = dict(config.SQLALCHEMY_CONFIG, root_path=root_path) + defined_config.update(override_config) + + schema = SQLAlchemyConfig.model_validate( + defined_config, from_attributes=True + ) + + schema.migration_options.directory = get_main_directory_by_stack( + schema.migration_options.directory, stack_level=2, from_dir=root_path + ) + + return module.__setup_module(schema) + raise RuntimeError("Could not find `SQLALCHEMY_CONFIG` in application config.") diff --git a/ellar_sqlalchemy/py.typed b/ellar_sqlalchemy/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/ellar_sqlalchemy/schemas.py b/ellar_sqlalchemy/schemas.py new file mode 100644 index 0000000..b6752da --- /dev/null +++ b/ellar_sqlalchemy/schemas.py @@ -0,0 +1,36 @@ +import typing as t +from dataclasses import asdict, dataclass, field + +import ellar.common as ecm + +from ellar_sqlalchemy.types import RevisionDirectiveCallable + + +@dataclass +class MigrationOption: + directory: str + revision_directives_callbacks: t.List[RevisionDirectiveCallable] = field( + default_factory=lambda: [] + ) + use_two_phase: bool = False + + def dict(self) -> t.Dict[str, t.Any]: + return asdict(self) + + +class SQLAlchemyConfig(ecm.Serializer): + model_config = {"arbitrary_types_allowed": True} + + databases: t.Union[str, t.Dict[str, t.Any]] + migration_options: MigrationOption + root_path: str + + session_options: t.Dict[str, t.Any] = { + "autoflush": False, + "future": True, + "expire_on_commit": False, + } + echo: bool = False + engine_options: t.Optional[t.Dict[str, t.Any]] = None + + models: t.Optional[t.List[str]] = None diff --git a/ellar_sqlalchemy/services/__init__.py b/ellar_sqlalchemy/services/__init__.py new file mode 100644 index 0000000..964e556 --- /dev/null +++ b/ellar_sqlalchemy/services/__init__.py @@ -0,0 +1,3 @@ +from .base import EllarSQLAlchemyService + +__all__ = ["EllarSQLAlchemyService"] diff --git a/ellar_sqlalchemy/services/base.py b/ellar_sqlalchemy/services/base.py new file mode 100644 index 0000000..dbfffa3 --- /dev/null +++ b/ellar_sqlalchemy/services/base.py @@ -0,0 +1,356 @@ +import os +import typing as t +from threading import get_ident +from weakref import WeakKeyDictionary + +import sqlalchemy as sa +import sqlalchemy.exc as sa_exc +import sqlalchemy.orm as sa_orm +from ellar.common.utils.importer import ( + get_main_directory_by_stack, + import_from_string, + module_import, +) +from ellar.events import app_context_teardown_events +from sqlalchemy.ext.asyncio import ( + AsyncSession, + async_scoped_session, + async_sessionmaker, +) + +from ellar_sqlalchemy.constant import ( + DEFAULT_KEY, + DeclarativeBasePlaceHolder, +) +from ellar_sqlalchemy.model import ( + make_metadata, +) +from ellar_sqlalchemy.model.base import Model +from ellar_sqlalchemy.model.database_binds import get_database_bind, get_database_binds +from ellar_sqlalchemy.schemas import MigrationOption +from ellar_sqlalchemy.session import ModelSession + +from .metadata_engine import MetaDataEngine + + +def _configure_model( + self: "EllarSQLAlchemyService", + models: t.Optional[t.List[str]] = None, +) -> None: + for model in models or []: + module_import(model) + + def model_get_session( + cls: t.Type[Model], + ) -> t.Union[sa_orm.Session, AsyncSession, t.Any]: + return self.session_factory() + + sql_alchemy_declarative_base = import_from_string( + "ellar_sqlalchemy.model.base:SQLAlchemyDefaultBase" + ) + base = ( + Model if sql_alchemy_declarative_base is None else sql_alchemy_declarative_base + ) + + get_db_session = getattr(base, "get_db_session", DeclarativeBasePlaceHolder) + get_db_session_name = get_db_session.__name__ if get_db_session else "" + + if get_db_session_name != "model_get_session": + base.get_db_session = classmethod(model_get_session) + + +class EllarSQLAlchemyService: + def __init__( + self, + databases: t.Union[str, t.Dict[str, t.Any]], + *, + common_session_options: t.Optional[t.Dict[str, t.Any]] = None, + common_engine_options: t.Optional[t.Dict[str, t.Any]] = None, + models: t.Optional[t.List[str]] = None, + echo: bool = False, + root_path: t.Optional[str] = None, + migration_options: t.Optional[MigrationOption] = None, + ) -> None: + self._engines: WeakKeyDictionary[ + "EllarSQLAlchemyService", + t.Dict[str, sa.engine.Engine], + ] = WeakKeyDictionary() + + self._engines.setdefault(self, {}) + self._session_options = common_session_options or {} + + self._common_engine_options = common_engine_options or {} + self._execution_path = get_main_directory_by_stack(root_path, 2) # type:ignore[arg-type] + + self.migration_options = migration_options or MigrationOption( + directory=get_main_directory_by_stack( + self._execution_path or "__main__/migrations", 2 + ) + ) + self._async_session_type: bool = False + + self._setup(databases, models=models, echo=echo) + self.session_factory = self.get_scoped_session() + app_context_teardown_events.connect(self._on_application_tear_down) + + async def _on_application_tear_down(self) -> None: + res = self.session_factory.remove() + if isinstance(res, t.Coroutine): + await res + + @property + def engines(self) -> t.Dict[str, sa.Engine]: + return dict(self._engines[self]) + + @property + def engine(self) -> sa.Engine: + assert self._engines[self].get( + DEFAULT_KEY + ), f"{self.__class__.__name__} configuration is not ready" + return self._engines[self][DEFAULT_KEY] + + def _setup( + self, + databases: t.Union[str, t.Dict[str, t.Any]], + models: t.Optional[t.List[str]] = None, + echo: bool = False, + ) -> None: + _configure_model(self, models) + self._build_engines(databases, echo) + + def _build_engines( + self, databases: t.Union[str, t.Dict[str, t.Any]], echo: bool + ) -> None: + engine_options: t.Dict[str, t.Dict[str, t.Any]] = {} + + if isinstance(databases, str): + common_engine_options = self._common_engine_options.copy() + common_engine_options["url"] = databases + engine_options.setdefault(DEFAULT_KEY, {}).update(common_engine_options) + + elif isinstance(databases, dict): + for key, value in databases.items(): + engine_options[key] = self._common_engine_options.copy() + + if isinstance(value, (str, sa.engine.URL)): + engine_options[key]["url"] = value + else: + engine_options[key].update(value) + else: + raise RuntimeError( + "Invalid databases data structure. Allowed datastructure, str or dict data type" + ) + + if DEFAULT_KEY not in engine_options: + raise RuntimeError( + f"`default` database must be present in databases parameter: {databases}" + ) + + engines = self._engines.setdefault(self, {}) + + for key, options in engine_options.items(): + make_metadata(key) + + options.setdefault("echo", echo) + options.setdefault("echo_pool", echo) + + self._validate_engine_option_defaults(options) + engines[key] = self._make_engine(options) + + found_async_engine = [ + engine for engine in engines.values() if engine.dialect.is_async + ] + if found_async_engine and len(found_async_engine) != len(engines): + raise Exception( + "Databases Configuration must either be all async or all synchronous type" + ) + + self._async_session_type = bool(len(found_async_engine)) + + def __validate_databases_input(self, *databases: str) -> t.Union[str, t.List[str]]: + _databases: t.Union[str, t.List[str]] = list(databases) + if len(_databases) == 0: + _databases = "__all__" + return _databases + + def create_all(self, *databases: str) -> None: + _databases = self.__validate_databases_input(*databases) + + metadata_engines = self._get_metadata_and_engine(_databases) + + if self._async_session_type and _databases == "__all__": + raise Exception( + "You are using asynchronous database configuration. Use `create_all_async` instead" + ) + + for metadata_engine in metadata_engines: + metadata_engine.create_all() + + def drop_all(self, *databases: str) -> None: + _databases = self.__validate_databases_input(*databases) + + metadata_engines = self._get_metadata_and_engine(_databases) + + if self._async_session_type and _databases == "__all__": + raise Exception( + "You are using asynchronous database configuration. Use `drop_all_async` instead" + ) + + for metadata_engine in metadata_engines: + metadata_engine.drop_all() + + def reflect(self, *databases: str) -> None: + _databases = self.__validate_databases_input(*databases) + + metadata_engines = self._get_metadata_and_engine(_databases) + + if self._async_session_type and _databases == "__all__": + raise Exception( + "You are using asynchronous database configuration. Use `reflect_async` instead" + ) + for metadata_engine in metadata_engines: + metadata_engine.reflect() + + async def create_all_async(self, *databases: str) -> None: + _databases = self.__validate_databases_input(*databases) + + metadata_engines = self._get_metadata_and_engine(_databases) + + for metadata_engine in metadata_engines: + if not metadata_engine.is_async(): + metadata_engine.create_all() + continue + await metadata_engine.create_all_async() + + async def drop_all_async(self, *databases: str) -> None: + _databases = self.__validate_databases_input(*databases) + + metadata_engines = self._get_metadata_and_engine(_databases) + + for metadata_engine in metadata_engines: + if not metadata_engine.is_async(): + metadata_engine.drop_all() + continue + await metadata_engine.drop_all_async() + + async def reflect_async(self, *databases: str) -> None: + _databases = self.__validate_databases_input(*databases) + + metadata_engines = self._get_metadata_and_engine(_databases) + + for metadata_engine in metadata_engines: + if not metadata_engine.is_async(): + metadata_engine.reflect() + continue + await metadata_engine.reflect_async() + + def get_scoped_session( + self, + **extra_options: t.Any, + ) -> t.Union[ + sa_orm.scoped_session[sa_orm.Session], + async_scoped_session[t.Union[AsyncSession, t.Any]], + ]: + options = self._session_options.copy() + options.update(extra_options) + + scope = options.pop("scopefunc", get_ident) + + factory = self._make_session_factory(options) + + if self._async_session_type: + return async_scoped_session(factory, scope) # type:ignore[arg-type] + + return sa_orm.scoped_session(factory, scope) # type:ignore[arg-type] + + def _make_session_factory( + self, options: t.Dict[str, t.Any] + ) -> t.Union[sa_orm.sessionmaker[sa_orm.Session], async_sessionmaker[AsyncSession]]: + if self._async_session_type: + options.setdefault("sync_session_class", ModelSession) + else: + options.setdefault("class_", ModelSession) + + session_class = options.get("class_", options.get("sync_session_class")) + + if session_class is ModelSession or issubclass(session_class, ModelSession): + options.update(engines=self._engines[self]) + + if self._async_session_type: + return async_sessionmaker(**options) + + return sa_orm.sessionmaker(**options) + + def _validate_engine_option_defaults(self, options: t.Dict[str, t.Any]) -> None: + url = sa.engine.make_url(options["url"]) + + if url.drivername in {"sqlite", "sqlite+pysqlite", "sqlite+aiosqlite"}: + if url.database is None or url.database in {"", ":memory:"}: + options["poolclass"] = sa.pool.StaticPool + + if "connect_args" not in options: + options["connect_args"] = {} + + options["connect_args"]["check_same_thread"] = False + + elif self._execution_path: + is_uri = url.query.get("uri", False) + + if is_uri: + db_str = url.database[5:] + else: + db_str = url.database + + if not os.path.isabs(db_str): + root_path = os.path.join(self._execution_path, "sqlite") + os.makedirs(root_path, exist_ok=True) + db_str = os.path.join(root_path, db_str) + + if is_uri: + db_str = f"file:{db_str}" + + options["url"] = url.set(database=db_str) + + elif url.drivername.startswith("mysql"): + # set queue defaults only when using queue pool + if ( + "pool_class" not in options + or options["pool_class"] is sa.pool.QueuePool + ): + options.setdefault("pool_recycle", 7200) + + if "charset" not in url.query: + options["url"] = url.update_query_dict({"charset": "utf8mb4"}) + + def _make_engine(self, options: t.Dict[str, t.Any]) -> sa.engine.Engine: + engine = sa.engine_from_config(options, prefix="") + + # if engine.dialect.is_async: + # return AsyncEngine(engine) + + return engine + + def _get_metadata_and_engine( + self, database: t.Union[str, t.List[str]] = "__all__" + ) -> t.List[MetaDataEngine]: + engines = self._engines[self] + + if database == "__all__": + keys: t.List[str] = list(get_database_binds()) + elif isinstance(database, str): + keys = [database] + else: + keys = database + + result: t.List[MetaDataEngine] = [] + + for key in keys: + try: + engine = engines[key] + except KeyError: + message = f"Bind key '{key}' is not in 'Database' config." + raise sa_exc.UnboundExecutionError(message) from None + + metadata = get_database_bind(key, certain=True) + result.append(MetaDataEngine(metadata=metadata, engine=engine)) + return result diff --git a/ellar_sqlalchemy/services/metadata_engine.py b/ellar_sqlalchemy/services/metadata_engine.py new file mode 100644 index 0000000..1d3dec3 --- /dev/null +++ b/ellar_sqlalchemy/services/metadata_engine.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import dataclasses + +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncEngine + + +@dataclasses.dataclass +class MetaDataEngine: + metadata: sa.MetaData + engine: sa.Engine + + def is_async(self) -> bool: + return self.engine.dialect.is_async + + def create_all(self) -> None: + self.metadata.create_all(bind=self.engine) + + async def create_all_async(self) -> None: + engine = AsyncEngine(self.engine) + async with engine.begin() as conn: + await conn.run_sync(self.metadata.create_all) + + def drop_all(self) -> None: + self.metadata.drop_all(bind=self.engine) + + async def drop_all_async(self) -> None: + engine = AsyncEngine(self.engine) + async with engine.begin() as conn: + await conn.run_sync(self.metadata.drop_all) + + def reflect(self) -> None: + self.metadata.reflect(bind=self.engine) + + async def reflect_async(self) -> None: + engine = AsyncEngine(self.engine) + async with engine.begin() as conn: + await conn.run_sync(self.metadata.reflect) diff --git a/ellar_sqlalchemy/session.py b/ellar_sqlalchemy/session.py new file mode 100644 index 0000000..d163c84 --- /dev/null +++ b/ellar_sqlalchemy/session.py @@ -0,0 +1,78 @@ +import typing as t + +import sqlalchemy as sa +import sqlalchemy.exc as sa_exc +import sqlalchemy.orm as sa_orm + +from ellar_sqlalchemy.constant import DEFAULT_KEY + +EngineType = t.Optional[t.Union[sa.engine.Engine, sa.engine.Connection]] + + +def _get_engine_from_clause( + clause: t.Optional[sa.ClauseElement], + engines: t.Mapping[str, sa.Engine], +) -> t.Optional[sa.Engine]: + table = None + + if clause is not None: + if isinstance(clause, sa.Table): + table = clause + elif isinstance(clause, sa.UpdateBase) and isinstance(clause.table, sa.Table): + table = clause.table + + if table is not None and "database_bind_key" in table.metadata.info: + key = table.metadata.info["database_bind_key"] + + if key not in engines: + raise sa_exc.UnboundExecutionError( + f"Database Bind key '{key}' is not in 'Database' config." + ) + + return engines[key] + + return None + + +class ModelSession(sa_orm.Session): + def __init__(self, engines: t.Mapping[str, sa.Engine], **kwargs: t.Any) -> None: + super().__init__(**kwargs) + self._engines = engines + self._model_changes: t.Dict[object, t.Tuple[t.Any, str]] = {} + + def get_bind( # type:ignore[override] + self, + mapper: t.Optional[t.Any] = None, + clause: t.Optional[t.Any] = None, + bind: EngineType = None, + **kwargs: t.Any, + ) -> EngineType: + if bind is not None: + return bind + + engines = self._engines + + if mapper is not None: + try: + mapper = sa.inspect(mapper) + except sa_exc.NoInspectionAvailable as e: + if isinstance(mapper, type): + raise sa_orm.exc.UnmappedClassError(mapper) from e + + raise + + engine = _get_engine_from_clause(mapper.local_table, engines) + + if engine is not None: + return engine + + if clause is not None: + engine = _get_engine_from_clause(clause, engines) + + if engine is not None: + return engine + + if DEFAULT_KEY in engines: + return engines[DEFAULT_KEY] + + return super().get_bind(mapper=mapper, clause=clause, bind=bind, **kwargs) diff --git a/ellar_sqlalchemy/templates/basic/README b/ellar_sqlalchemy/templates/basic/README new file mode 100644 index 0000000..eaae251 --- /dev/null +++ b/ellar_sqlalchemy/templates/basic/README @@ -0,0 +1 @@ +Multi-database configuration for Flask. diff --git a/ellar_sqlalchemy/templates/basic/alembic.ini.mako b/ellar_sqlalchemy/templates/basic/alembic.ini.mako new file mode 100644 index 0000000..fc59da6 --- /dev/null +++ b/ellar_sqlalchemy/templates/basic/alembic.ini.mako @@ -0,0 +1,50 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/ellar_sqlalchemy/templates/basic/env.py b/ellar_sqlalchemy/templates/basic/env.py new file mode 100644 index 0000000..4edeef8 --- /dev/null +++ b/ellar_sqlalchemy/templates/basic/env.py @@ -0,0 +1,48 @@ +import asyncio +import typing as t +from logging.config import fileConfig + +from alembic import context +from ellar.app import current_injector + +from ellar_sqlalchemy.migrations import ( + MultipleDatabaseAlembicEnvMigration, + SingleDatabaseAlembicEnvMigration, +) +from ellar_sqlalchemy.services import EllarSQLAlchemyService + +db_service: EllarSQLAlchemyService = current_injector.get(EllarSQLAlchemyService) + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) # type:ignore[arg-type] +# logger = logging.getLogger("alembic.env") + + +AlembicEnvMigrationKlass: t.Type[ + t.Union[MultipleDatabaseAlembicEnvMigration, SingleDatabaseAlembicEnvMigration] +] = ( + MultipleDatabaseAlembicEnvMigration + if len(db_service.engines) > 1 + else SingleDatabaseAlembicEnvMigration +) + + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +alembic_env_migration = AlembicEnvMigrationKlass(db_service) + +if context.is_offline_mode(): + alembic_env_migration.run_migrations_offline(context) # type:ignore[arg-type] +else: + asyncio.get_event_loop().run_until_complete( + alembic_env_migration.run_migrations_online(context) # type:ignore[arg-type] + ) diff --git a/ellar_sqlalchemy/templates/basic/script.py.mako b/ellar_sqlalchemy/templates/basic/script.py.mako new file mode 100644 index 0000000..4b7a50f --- /dev/null +++ b/ellar_sqlalchemy/templates/basic/script.py.mako @@ -0,0 +1,63 @@ +<%! +import re + +%>"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + +<%! + from ellar.app import current_injector + from ellar_sqlalchemy.services import EllarSQLAlchemyService + + db_service = current_injector.get(EllarSQLAlchemyService) + db_names = list(db_service.engines.keys()) +%> + +% if len(db_names) > 1: + +def upgrade(engine_name): + globals()["upgrade_%s" % engine_name]() + + +def downgrade(engine_name): + globals()["downgrade_%s" % engine_name]() + + + +## generate an "upgrade_() / downgrade_()" function +## for each database name in the ini file. + +% for db_name in db_names: + +def upgrade_${db_name}(): + ${context.get("%s_upgrades" % db_name, "pass")} + + +def downgrade_${db_name}(): + ${context.get("%s_downgrades" % db_name, "pass")} + +% endfor + +% else: + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} + +% endif diff --git a/ellar_sqlalchemy/types.py b/ellar_sqlalchemy/types.py new file mode 100644 index 0000000..79e7775 --- /dev/null +++ b/ellar_sqlalchemy/types.py @@ -0,0 +1,15 @@ +import typing as t + +if t.TYPE_CHECKING: + from alembic.operations import MigrationScript + from alembic.runtime.migration import MigrationContext + +RevisionArgs = t.Union[ + str, + t.Iterable[t.Optional[str]], + t.Iterable[str], +] + +RevisionDirectiveCallable = t.Callable[ + ["MigrationContext", RevisionArgs, t.List["MigrationScript"]], None +] diff --git a/examples/single-db/README.md b/examples/single-db/README.md new file mode 100644 index 0000000..4e9d856 --- /dev/null +++ b/examples/single-db/README.md @@ -0,0 +1,26 @@ +## Ellar SQLAlchemy Single Database Example +Project Description + +## Requirements +Python >= 3.7 +Starlette +Injector + +## Project setup +```shell +pip install poetry +``` +then, +```shell +poetry install +``` +### Apply Migration +```shell +ellar db upgrade +``` + +### Development Server +```shell +ellar runserver --reload +``` +then, visit [http://127.0.0.1:8000/docs](http://127.0.0.1:8000/docs) \ No newline at end of file diff --git a/examples/single-db/db/__init__.py b/examples/single-db/db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/single-db/db/controllers.py b/examples/single-db/db/controllers.py new file mode 100644 index 0000000..0632fb8 --- /dev/null +++ b/examples/single-db/db/controllers.py @@ -0,0 +1,47 @@ +""" +Define endpoints routes in python class-based fashion +example: + +@Controller("/dogs", tag="Dogs", description="Dogs Resources") +class MyController(ControllerBase): + @get('/') + def index(self): + return {'detail': "Welcome Dog's Resources"} +""" +from ellar.common import Controller, ControllerBase, get, post, Body +from pydantic import EmailStr +from sqlalchemy import select + +from .models.users import User + + +@Controller +class DbController(ControllerBase): + + @get("/") + def index(self): + return {"detail": "Welcome Db Resource"} + + @post("/users") + def create_user(self, username: Body[str], email: Body[EmailStr]): + session = User.get_db_session() + user = User(username=username, email=email) + + session.add(user) + session.commit() + + return user.dict() + + @get("/users/{user_id:int}") + def get_user_by_id(self, user_id: int): + session = User.get_db_session() + stmt = select(User).filter(User.id == user_id) + user = session.execute(stmt).scalar() + return user.dict() + + @get("/users") + async def get_all_users(self): + session = User.get_db_session() + stmt = select(User) + rows = session.execute(stmt.offset(0).limit(100)).scalars() + return [row.dict() for row in rows] diff --git a/examples/single-db/db/migrations/README b/examples/single-db/db/migrations/README new file mode 100644 index 0000000..eaae251 --- /dev/null +++ b/examples/single-db/db/migrations/README @@ -0,0 +1 @@ +Multi-database configuration for Flask. diff --git a/examples/single-db/db/migrations/alembic.ini b/examples/single-db/db/migrations/alembic.ini new file mode 100644 index 0000000..fc59da6 --- /dev/null +++ b/examples/single-db/db/migrations/alembic.ini @@ -0,0 +1,50 @@ +# A generic, single database configuration. + +[alembic] +# template used to generate migration files +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic,flask_migrate + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[logger_flask_migrate] +level = INFO +handlers = +qualname = flask_migrate + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/examples/single-db/db/migrations/env.py b/examples/single-db/db/migrations/env.py new file mode 100644 index 0000000..4edeef8 --- /dev/null +++ b/examples/single-db/db/migrations/env.py @@ -0,0 +1,48 @@ +import asyncio +import typing as t +from logging.config import fileConfig + +from alembic import context +from ellar.app import current_injector + +from ellar_sqlalchemy.migrations import ( + MultipleDatabaseAlembicEnvMigration, + SingleDatabaseAlembicEnvMigration, +) +from ellar_sqlalchemy.services import EllarSQLAlchemyService + +db_service: EllarSQLAlchemyService = current_injector.get(EllarSQLAlchemyService) + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +fileConfig(config.config_file_name) # type:ignore[arg-type] +# logger = logging.getLogger("alembic.env") + + +AlembicEnvMigrationKlass: t.Type[ + t.Union[MultipleDatabaseAlembicEnvMigration, SingleDatabaseAlembicEnvMigration] +] = ( + MultipleDatabaseAlembicEnvMigration + if len(db_service.engines) > 1 + else SingleDatabaseAlembicEnvMigration +) + + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +alembic_env_migration = AlembicEnvMigrationKlass(db_service) + +if context.is_offline_mode(): + alembic_env_migration.run_migrations_offline(context) # type:ignore[arg-type] +else: + asyncio.get_event_loop().run_until_complete( + alembic_env_migration.run_migrations_online(context) # type:ignore[arg-type] + ) diff --git a/examples/single-db/db/migrations/script.py.mako b/examples/single-db/db/migrations/script.py.mako new file mode 100644 index 0000000..4b7a50f --- /dev/null +++ b/examples/single-db/db/migrations/script.py.mako @@ -0,0 +1,63 @@ +<%! +import re + +%>"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision = ${repr(up_revision)} +down_revision = ${repr(down_revision)} +branch_labels = ${repr(branch_labels)} +depends_on = ${repr(depends_on)} + +<%! + from ellar.app import current_injector + from ellar_sqlalchemy.services import EllarSQLAlchemyService + + db_service = current_injector.get(EllarSQLAlchemyService) + db_names = list(db_service.engines.keys()) +%> + +% if len(db_names) > 1: + +def upgrade(engine_name): + globals()["upgrade_%s" % engine_name]() + + +def downgrade(engine_name): + globals()["downgrade_%s" % engine_name]() + + + +## generate an "upgrade_() / downgrade_()" function +## for each database name in the ini file. + +% for db_name in db_names: + +def upgrade_${db_name}(): + ${context.get("%s_upgrades" % db_name, "pass")} + + +def downgrade_${db_name}(): + ${context.get("%s_downgrades" % db_name, "pass")} + +% endfor + +% else: + +def upgrade(): + ${upgrades if upgrades else "pass"} + + +def downgrade(): + ${downgrades if downgrades else "pass"} + +% endif diff --git a/examples/single-db/db/migrations/versions/2023_12_30_2053-b7712f83d45b_first_migration.py b/examples/single-db/db/migrations/versions/2023_12_30_2053-b7712f83d45b_first_migration.py new file mode 100644 index 0000000..562feb5 --- /dev/null +++ b/examples/single-db/db/migrations/versions/2023_12_30_2053-b7712f83d45b_first_migration.py @@ -0,0 +1,39 @@ +"""first migration + +Revision ID: b7712f83d45b +Revises: +Create Date: 2023-12-30 20:53:37.393009 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'b7712f83d45b' +down_revision = None +branch_labels = None +depends_on = None + + + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('user', + sa.Column('id', sa.Integer(), nullable=False), + sa.Column('username', sa.String(), nullable=False), + sa.Column('email', sa.String(), nullable=False), + sa.Column('created_date', sa.DateTime(), nullable=False), + sa.Column('time_updated', sa.DateTime(), nullable=False), + sa.PrimaryKeyConstraint('id', name=op.f('pk_user')), + sa.UniqueConstraint('username', name=op.f('uq_user_username')) + ) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('user') + # ### end Alembic commands ### + diff --git a/examples/single-db/db/models/__init__.py b/examples/single-db/db/models/__init__.py new file mode 100644 index 0000000..67c9337 --- /dev/null +++ b/examples/single-db/db/models/__init__.py @@ -0,0 +1,5 @@ +from .users import User + +__all__ = [ + 'User', +] diff --git a/examples/single-db/db/models/base.py b/examples/single-db/db/models/base.py new file mode 100644 index 0000000..0f35189 --- /dev/null +++ b/examples/single-db/db/models/base.py @@ -0,0 +1,27 @@ +from datetime import datetime +from sqlalchemy import DateTime, func, MetaData +from sqlalchemy.orm import Mapped, mapped_column + +from ellar_sqlalchemy.model import Model + +convention = { + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +} + + +class Base(Model, as_base=True): + __database__ = 'default' + + metadata = MetaData(naming_convention=convention) + + created_date: Mapped[datetime] = mapped_column( + "created_date", DateTime, default=datetime.utcnow, nullable=False + ) + + time_updated: Mapped[datetime] = mapped_column( + "time_updated", DateTime, nullable=False, default=datetime.utcnow, onupdate=func.now() + ) diff --git a/examples/single-db/db/models/users.py b/examples/single-db/db/models/users.py new file mode 100644 index 0000000..489a36d --- /dev/null +++ b/examples/single-db/db/models/users.py @@ -0,0 +1,18 @@ + +from sqlalchemy import Integer, String +from sqlalchemy.orm import Mapped, mapped_column +from .base import Base + +class User(Base): + id: Mapped[int] = mapped_column(Integer, primary_key=True) + username: Mapped[str] = mapped_column(String, unique=True, nullable=False) + email: Mapped[str] = mapped_column(String) + + + +assert getattr(User, '__dnd__', None) == 'Ellar' + +# assert session + + + diff --git a/examples/single-db/db/module.py b/examples/single-db/db/module.py new file mode 100644 index 0000000..7ff4082 --- /dev/null +++ b/examples/single-db/db/module.py @@ -0,0 +1,56 @@ +""" +@Module( + controllers=[MyController], + providers=[ + YourService, + ProviderConfig(IService, use_class=AService), + ProviderConfig(IFoo, use_value=FooService()), + ], + routers=(routerA, routerB) + statics='statics', + template='template_folder', + # base_directory -> default is the `db` folder +) +class MyModule(ModuleBase): + def register_providers(self, container: Container) -> None: + # for more complicated provider registrations + pass + +""" +from ellar.app import App +from ellar.common import Module, IApplicationStartup +from ellar.core import ModuleBase +from ellar.di import Container +from ellar_sqlalchemy import EllarSQLAlchemyModule, EllarSQLAlchemyService + +from .controllers import DbController + + +@Module( + controllers=[DbController], + providers=[], + routers=[], + modules=[ + EllarSQLAlchemyModule.setup( + databases={ + 'default': 'sqlite:///project.db', + }, + echo=True, + migration_options={ + 'directory': '__main__/migrations' + }, + models=['db.models'] + ) + ] +) +class DbModule(ModuleBase, IApplicationStartup): + """ + Db Module + """ + + async def on_startup(self, app: App) -> None: + db_service = app.injector.get(EllarSQLAlchemyService) + # db_service.create_all() + + def register_providers(self, container: Container) -> None: + """for more complicated provider registrations, use container.register_instance(...) """ \ No newline at end of file diff --git a/examples/single-db/db/sqlite/project.db b/examples/single-db/db/sqlite/project.db new file mode 100644 index 0000000..77a0830 Binary files /dev/null and b/examples/single-db/db/sqlite/project.db differ diff --git a/examples/single-db/db/tests/__init__.py b/examples/single-db/db/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/single-db/pyproject.toml b/examples/single-db/pyproject.toml new file mode 100644 index 0000000..c9bb429 --- /dev/null +++ b/examples/single-db/pyproject.toml @@ -0,0 +1,26 @@ +[tool.poetry] +name = "single-db" +version = "0.1.0" +description = "Demonstrating SQLAlchemy with Ellar" +authors = ["Ezeudoh Tochukwu "] +license = "MIT" +readme = "README.md" +packages = [{include = "single_db"}] + +[tool.poetry.dependencies] +python = "^3.8" +ellar-cli = "^0.2.6" +ellar = "^0.6.2" + + +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[ellar] +default = "single_db" +[ellar.projects.single_db] +project-name = "single_db" +application = "single_db.server:application" +config = "single_db.config:DevelopmentConfig" +root-module = "single_db.root_module:ApplicationModule" diff --git a/examples/single-db/single_db/__init__.py b/examples/single-db/single_db/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/examples/single-db/single_db/config.py b/examples/single-db/single_db/config.py new file mode 100644 index 0000000..d393f88 --- /dev/null +++ b/examples/single-db/single_db/config.py @@ -0,0 +1,82 @@ +""" +Application Configurations +Default Ellar Configurations are exposed here through `ConfigDefaultTypesMixin` +Make changes and define your own configurations specific to your application + +export ELLAR_CONFIG_MODULE=ellar_sqlachemy_example.config:DevelopmentConfig +""" + +import typing as t + +from ellar.pydantic import ENCODERS_BY_TYPE as encoders_by_type +from starlette.middleware import Middleware +from ellar.common import IExceptionHandler, JSONResponse +from ellar.core import ConfigDefaultTypesMixin +from ellar.core.versioning import BaseAPIVersioning, DefaultAPIVersioning + + +class BaseConfig(ConfigDefaultTypesMixin): + DEBUG: bool = False + + DEFAULT_JSON_CLASS: t.Type[JSONResponse] = JSONResponse + SECRET_KEY: str = "ellar_wltsSVEySCVC3xC2i4a0y40jlbcjTupkCX0TSoUT-R4" + + # injector auto_bind = True allows you to resolve types that are not registered on the container + # For more info, read: https://injector.readthedocs.io/en/latest/index.html + INJECTOR_AUTO_BIND = False + + # jinja Environment options + # https://jinja.palletsprojects.com/en/3.0.x/api/#high-level-api + JINJA_TEMPLATES_OPTIONS: t.Dict[str, t.Any] = {} + + # Application route versioning scheme + VERSIONING_SCHEME: BaseAPIVersioning = DefaultAPIVersioning() + + # Enable or Disable Application Router route searching by appending backslash + REDIRECT_SLASHES: bool = False + + # Define references to static folders in python packages. + # eg STATIC_FOLDER_PACKAGES = [('boostrap4', 'statics')] + STATIC_FOLDER_PACKAGES: t.Optional[t.List[t.Union[str, t.Tuple[str, str]]]] = [] + + # Define references to static folders defined within the project + STATIC_DIRECTORIES: t.Optional[t.List[t.Union[str, t.Any]]] = [] + + # static route path + STATIC_MOUNT_PATH: str = "/static" + + CORS_ALLOW_ORIGINS: t.List[str] = ["*"] + CORS_ALLOW_METHODS: t.List[str] = ["*"] + CORS_ALLOW_HEADERS: t.List[str] = ["*"] + ALLOWED_HOSTS: t.List[str] = ["*"] + + # Application middlewares + MIDDLEWARE: t.Sequence[Middleware] = [] + + # A dictionary mapping either integer status codes, + # or exception class types onto callables which handle the exceptions. + # Exception handler callables should be of the form + # `handler(context:IExecutionContext, exc: Exception) -> response` + # and may be either standard functions, or async functions. + EXCEPTION_HANDLERS: t.List[IExceptionHandler] = [] + + # Object Serializer custom encoders + SERIALIZER_CUSTOM_ENCODER: t.Dict[ + t.Any, t.Callable[[t.Any], t.Any] + ] = encoders_by_type + + +class DevelopmentConfig(BaseConfig): + DEBUG: bool = True + # Configuration through Confog + SQLALCHEMY_CONFIG: t.Dict[str, t.Any] = { + 'databases': { + 'default': 'sqlite+aiosqlite:///project.db', + # 'db2': 'sqlite+aiosqlite:///project2.db', + }, + 'echo': True, + 'migration_options': { + 'directory': '__main__/migrations' + }, + 'models': ['db.models'] + } \ No newline at end of file diff --git a/examples/single-db/single_db/root_module.py b/examples/single-db/single_db/root_module.py new file mode 100644 index 0000000..710e4a2 --- /dev/null +++ b/examples/single-db/single_db/root_module.py @@ -0,0 +1,10 @@ +from ellar.common import Module, exception_handler, IExecutionContext, JSONResponse, Response +from ellar.core import ModuleBase, LazyModuleImport as lazyLoad +from ellar.samples.modules import HomeModule + + +@Module(modules=[HomeModule, lazyLoad('db.module:DbModule')]) +class ApplicationModule(ModuleBase): + @exception_handler(404) + def exception_404_handler(cls, ctx: IExecutionContext, exc: Exception) -> Response: + return JSONResponse(dict(detail="Resource not found."), status_code=404) \ No newline at end of file diff --git a/examples/single-db/single_db/server.py b/examples/single-db/single_db/server.py new file mode 100644 index 0000000..4b06927 --- /dev/null +++ b/examples/single-db/single_db/server.py @@ -0,0 +1,28 @@ +import os + +from ellar.app import AppFactory +from ellar.common.constants import ELLAR_CONFIG_MODULE +from ellar.core import LazyModuleImport as lazyLoad +from ellar.openapi import OpenAPIDocumentModule, OpenAPIDocumentBuilder, SwaggerUI + +application = AppFactory.create_from_app_module( + lazyLoad("single_db.root_module:ApplicationModule"), + config_module=os.environ.get( + ELLAR_CONFIG_MODULE, "single_db.config:DevelopmentConfig" + ), + global_guards=[] +) + +document_builder = OpenAPIDocumentBuilder() +document_builder.set_title('Ellar Sqlalchemy Single Database Example') \ + .set_version('1.0.2') \ + .set_contact(name='Author Name', url='https://www.author-name.com', email='authorname@gmail.com') \ + .set_license('MIT Licence', url='https://www.google.com') + +document = document_builder.build_document(application) +module = OpenAPIDocumentModule.setup( + document=document, + docs_ui=SwaggerUI(dark_theme=True), + guards=[] +) +application.install_module(module) \ No newline at end of file diff --git a/examples/single-db/tests/conftest.py b/examples/single-db/tests/conftest.py new file mode 100644 index 0000000..b18f954 --- /dev/null +++ b/examples/single-db/tests/conftest.py @@ -0,0 +1 @@ +from ellar.testing import Test, TestClient \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..bddfc1a --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,94 @@ +[build-system] +requires = ["flit_core >=2,<4"] +build-backend = "flit_core.buildapi" + +[tool.flit.module] +name = "ellar_sqlalchemy" + + +[project] +name = "ellar-sqlalchemy" +authors = [ + {name = "Ezeudoh Tochukwu", email = "tochukwu.ezeudoh@gmail.com"}, +] +dynamic = ["version", "description"] +requires-python = ">=3.8" +readme = "README.md" +home-page = "https://github.com/python-ellar/ellar-sqlalchemy" +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Topic :: Internet :: WWW/HTTP :: Dynamic Content", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", +] + +dependencies = [ + "ellar >= 0.6.2", + "sqlalchemy >=2.0.16", + "alembic >= 1.10.0", +] + +dev = [ + "pre-commit" +] + +[project.urls] +Documentation = "https://github.com/python-ellar/ellar-sqlalchemy" +Source = "https://github.com/python-ellar/ellar-sqlalchemy" +Homepage = "https://python-ellar.github.io/ellar-sqlalchemy/" +"Bug Tracker" = "https://github.com/python-ellar/ellar-sqlalchemy/issues" + +[project.optional-dependencies] +test = [ + "pytest >= 7.1.3,<8.0.0", + "pytest-cov >= 2.12.0,<5.0.0", + "ruff ==0.1.7", + "mypy == 1.7.1", + "autoflake", +] + +[tool.ruff] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "C", # flake8-comprehensions + "B", # flake8-bugbear +] +ignore = [ + "E501", # line too long, handled by black + "B008", # do not perform function calls in argument defaults + "C901", # too complex +] + +[tool.ruff.per-file-ignores] +"__init__.py" = ["F401"] + +[tool.ruff.isort] +known-third-party = ["ellar"] + +[tool.mypy] +python_version = "3.8" +show_error_codes = true +pretty = true +strict = true +# db.Model attribute doesn't recognize subclassing +disable_error_code = ["name-defined", 'union-attr'] +# db.Model is Any +disallow_subclassing_any = false +[[tool.mypy.overrides]] +module = "ellar_sqlalchemy.cli.commands" +ignore_errors = true +[[tool.mypy.overrides]] +module = "ellar_sqlalchemy.migrations.*" +disable_error_code = ["arg-type", 'union-attr'] \ No newline at end of file diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..b1f9d2f --- /dev/null +++ b/pytest.ini @@ -0,0 +1,8 @@ +[pytest] +addopts = --strict-config --strict-markers +xfail_strict = true +junit_family = "xunit2" +norecursedirs = examples/* + +[pytest-watch] +runner= pytest --failed-first --maxfail=1 --no-success-flaky-report diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..e69de29