From 29db7acc542026f3046d9ca96df55956129f0212 Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Fri, 30 Aug 2024 17:23:06 -0600 Subject: [PATCH 1/8] Rename AsyncSQLAlchemyMiddleware to AsyncSQLModelMiddleware Signed-off-by: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> --- README.md | 2 +- src/fastapi_async_sql/middlewares.py | 16 +++++++++++++++- tests/conftest.py | 4 ++-- tests/test_middlewares.py | 22 ++++++++++------------ 4 files changed, 28 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 03f18eb..caf9404 100644 --- a/README.md +++ b/README.md @@ -12,7 +12,7 @@ uv pip install fastapi-async-sql ## Features -- AsyncSQLAlchemyMiddleware: A middleware to handle database connections with AsyncSQLAlchemy +- AsyncSQLModelMiddleware: A middleware to handle database connections with AsyncSQLAlchemy - [SQLModel](https://sqlmodel.tiangolo.com/): A library to handle database models with Pydantic and SQLAlchemy - Base models for `SQLModel`: - `BaseSQLModel`: A opinionated base model for SQLAlchemy models diff --git a/src/fastapi_async_sql/middlewares.py b/src/fastapi_async_sql/middlewares.py index b558e91..ba66ee6 100644 --- a/src/fastapi_async_sql/middlewares.py +++ b/src/fastapi_async_sql/middlewares.py @@ -12,7 +12,21 @@ ) -class AsyncSQLAlchemyMiddleware(BaseHTTPMiddleware): +class AsyncSQLModelMiddleware(BaseHTTPMiddleware): + """Middleware to handle the database session. + + /// info | Usage Documentation + [Middlewares](../concepts/middlewares.md#asyncsqlmodelmiddleware) + /// + + Attributes: + app (ASGIApp): The ASGI app. + db_url (str | None): The database URL. Defaults to None. + custom_engine (AsyncEngine | None): The custom engine. Defaults to None. + session_options (dict | None): The session options. Defaults to None. + engine_options (dict | None): The engine options. Defaults to None. + """ + def __init__( self, app: ASGIApp, diff --git a/tests/conftest.py b/tests/conftest.py index d891b09..314198e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine from sqlmodel.ext.asyncio.session import AsyncSession -from fastapi_async_sql.middlewares import AsyncSQLAlchemyMiddleware +from fastapi_async_sql.middlewares import AsyncSQLModelMiddleware from fastapi_async_sql.models import BaseSQLModel from tests.factories import HeroFactory, ItemFactory, TeamFactory, register_factories @@ -64,7 +64,7 @@ def app(engine: AsyncEngine) -> FastAPI: """Create the FastAPI app.""" app = FastAPI() app.add_middleware( - AsyncSQLAlchemyMiddleware, # noqa + AsyncSQLModelMiddleware, # noqa custom_engine=engine, ) return app diff --git a/tests/test_middlewares.py b/tests/test_middlewares.py index 74128cd..ab2e50d 100644 --- a/tests/test_middlewares.py +++ b/tests/test_middlewares.py @@ -12,7 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware from fastapi_async_sql.exceptions import MissingArgsError, MultipleArgsError -from fastapi_async_sql.middlewares import AsyncSQLAlchemyMiddleware +from fastapi_async_sql.middlewares import AsyncSQLModelMiddleware from tests.models.hero_model import Hero from tests.models.item_model import Item @@ -28,8 +28,8 @@ def app(): @pytest.fixture(scope="function") def app_with_db_middleware(app, engine: AsyncEngine): - """Create a FastAPI app with the AsyncSQLAlchemyMiddleware.""" - app.add_middleware(AsyncSQLAlchemyMiddleware, custom_engine=engine) # noqa + """Create a FastAPI app with the AsyncSQLModelMiddleware.""" + app.add_middleware(AsyncSQLModelMiddleware, custom_engine=engine) # noqa return app @@ -44,7 +44,7 @@ async def client(app: FastAPI, test_server_url: str) -> httpx.AsyncClient: async def test_init_async_sqlalchemy_middleware(app: FastAPI, database_url: str): """Test that the middleware is correctly initialised.""" - mw = AsyncSQLAlchemyMiddleware(app, db_url=database_url) + mw = AsyncSQLModelMiddleware(app, db_url=database_url) assert isinstance(mw, BaseHTTPMiddleware) @@ -52,14 +52,14 @@ async def test_init_async_sqlalchemy_middleware_custom_engine( app: FastAPI, engine: AsyncEngine ): """Test that the middleware is correctly initialised with a custom engine.""" - mw = AsyncSQLAlchemyMiddleware(app, custom_engine=engine) + mw = AsyncSQLModelMiddleware(app, custom_engine=engine) assert isinstance(mw, BaseHTTPMiddleware) async def test_init_async_sqlalchemy_middleware_missing_required_args(app: FastAPI): """Test that the middleware raises an error if no db_url or custom_engine is passed.""" with pytest.raises(MissingArgsError) as exc: - AsyncSQLAlchemyMiddleware(app) + AsyncSQLModelMiddleware(app) assert str(exc.value) == "You need to pass db_url or custom_engine parameter." @@ -68,7 +68,7 @@ async def test_init_async_sqlalchemy_middleware_multiple_args( ): """Test that the middleware raises an error if both db_url and custom_engine are passed.""" with pytest.raises(MultipleArgsError) as exc: - AsyncSQLAlchemyMiddleware(app, db_url=database_url, custom_engine=engine) + AsyncSQLModelMiddleware(app, db_url=database_url, custom_engine=engine) assert str(exc.value) == "Mutually exclusive parameters: db_url, custom_engine." @@ -79,7 +79,7 @@ async def test_init_async_sqlalchemy_middleware_correct_optional_args( engine_options = {"echo": True, "poolclass": NullPool} session_options = {"autoflush": False} - mw = AsyncSQLAlchemyMiddleware( + mw = AsyncSQLModelMiddleware( app, database_url, engine_options=engine_options, @@ -99,12 +99,10 @@ async def test_init_async_sqlalchemy_middleware_incorrect_optional_args( ): """Test that the middleware is correctly initialised with incorrect optional arguments.""" with pytest.raises(TypeError) as exc: - AsyncSQLAlchemyMiddleware( - app, db_url="sqlite+aiosqlite://", invalid_args="test" - ) + AsyncSQLModelMiddleware(app, db_url="sqlite+aiosqlite://", invalid_args="test") assert ( str(exc.value) - == "AsyncSQLAlchemyMiddleware.__init__() got an unexpected keyword argument 'invalid_args'" + == "AsyncSQLModelMiddleware.__init__() got an unexpected keyword argument 'invalid_args'" ) From 196112920b890e3d803b165b9a44e02391f99a0b Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Fri, 30 Aug 2024 17:25:22 -0600 Subject: [PATCH 2/8] Fix PK type on BaseRepository Signed-off-by: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> --- src/fastapi_async_sql/repositories.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/fastapi_async_sql/repositories.py b/src/fastapi_async_sql/repositories.py index d032d13..26eec7a 100644 --- a/src/fastapi_async_sql/repositories.py +++ b/src/fastapi_async_sql/repositories.py @@ -3,7 +3,7 @@ from collections.abc import Sequence from typing import Any, Generic, TypeVar -from pydantic import UUID4, BaseModel +from pydantic import BaseModel from sqlalchemy import exc from sqlmodel import SQLModel, func, select from sqlmodel.ext.asyncio.session import AsyncSession @@ -18,7 +18,7 @@ UpdateSchemaType = TypeVar("UpdateSchemaType", bound=BaseModel) SchemaType = TypeVar("SchemaType", bound=BaseModel) T = TypeVar("T", bound=SQLModel) -PK = UUID4 +PK = Any class BaseRepository(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): @@ -35,9 +35,7 @@ def __init__(self, model: type[ModelType], db: AsyncSession | None = None): self.model = model self.db = db - async def get( - self, *, id: PK, db_session: AsyncSession | None = None - ) -> ModelType | None: + async def get(self, *, id: PK, db_session: AsyncSession | None = None) -> ModelType: """Get a single object by ID.""" session = self._get_db_session(db_session) response = await session.get(self.model, id) From 3dc90c37a8a927190ff3c7361588c749eda8d00d Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Fri, 30 Aug 2024 17:25:52 -0600 Subject: [PATCH 3/8] wip: write docs Signed-off-by: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> --- .github/workflows/ci.yml | 5 +- docs/api_doc/async_sql_model_middleware.md | 11 ++ docs/api_doc/base_repository.md | 18 +++ docs/api_doc/base_sql_model.md | 10 ++ docs/api_doc/base_timestamp_model.md | 11 ++ docs/api_doc/base_uuid_model.md | 10 ++ docs/concepts/filtering.md | 0 docs/concepts/middlewares.md | 32 +++++ docs/concepts/models.md | 120 ++++++++++++++++++ docs/concepts/pagination.md | 0 docs/concepts/repositories.md | 118 +++++++++++++++++ docs/extra/style.css | 15 +++ docs/index.md | 2 +- docs_src/middlewares/__init__.py | 0 docs_src/middlewares/main.py | 15 +++ docs_src/middlewares/middleware.py | 36 ++++++ docs_src/models/__init__.py | 0 docs_src/models/async_attrs_model.py | 33 +++++ docs_src/models/combined_model.py | 5 + docs_src/models/model.py | 8 ++ docs_src/models/timestamp_model.py | 5 + docs_src/models/uuid_model.py | 5 + docs_src/repositories/__init__.py | 0 docs_src/repositories/repository.py | 76 +++++++++++ mkdocs.yml | 71 +++++++++-- pyproject.toml | 5 + src/fastapi_async_sql/models.py | 27 +++- src/fastapi_async_sql/repositories.py | 13 +- tests/repositories/hero_repository.py | 4 +- tests/repositories/item_repository.py | 4 +- .../{hero_schema.py => hero_schemas.py} | 16 +-- .../{item_schema.py => item_schemas.py} | 6 +- tests/test_api.py | 12 +- tests/test_middlewares.py | 4 +- tests/test_repositories.py | 12 +- uv.lock | 99 +++++++++++++++ 36 files changed, 764 insertions(+), 44 deletions(-) create mode 100644 docs/api_doc/async_sql_model_middleware.md create mode 100644 docs/api_doc/base_repository.md create mode 100644 docs/api_doc/base_sql_model.md create mode 100644 docs/api_doc/base_timestamp_model.md create mode 100644 docs/api_doc/base_uuid_model.md create mode 100644 docs/concepts/filtering.md create mode 100644 docs/concepts/middlewares.md create mode 100644 docs/concepts/models.md create mode 100644 docs/concepts/pagination.md create mode 100644 docs/concepts/repositories.md create mode 100644 docs/extra/style.css create mode 100644 docs_src/middlewares/__init__.py create mode 100644 docs_src/middlewares/main.py create mode 100644 docs_src/middlewares/middleware.py create mode 100644 docs_src/models/__init__.py create mode 100644 docs_src/models/async_attrs_model.py create mode 100644 docs_src/models/combined_model.py create mode 100644 docs_src/models/model.py create mode 100644 docs_src/models/timestamp_model.py create mode 100644 docs_src/models/uuid_model.py create mode 100644 docs_src/repositories/__init__.py create mode 100644 docs_src/repositories/repository.py rename tests/schemas/{hero_schema.py => hero_schemas.py} (58%) rename tests/schemas/{item_schema.py => item_schemas.py} (66%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0db9364..2feceb0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: ci +name: Build Docs on: push: branches: @@ -29,4 +29,7 @@ jobs: pip install mkdocs-material pip install mkdocs-git-revision-date-localized-plugin pip install termynal + pip install mdx_include + pip install 'mkdocstrings[python]' + # TODO: Use uv to install docs optional dependencies - run: mkdocs gh-deploy --force diff --git a/docs/api_doc/async_sql_model_middleware.md b/docs/api_doc/async_sql_model_middleware.md new file mode 100644 index 0000000..9f71536 --- /dev/null +++ b/docs/api_doc/async_sql_model_middleware.md @@ -0,0 +1,11 @@ +# AsyncSQLModelMiddleware +The `AsyncSQLModelMiddleware` class is a FastAPI middleware that provides an asynchronous SQLModel session (`AsyncSession`) for each request. It ensures that each request has a database session available via `request.state.db`. + +::: fastapi_async_sql.middlewares.AsyncSQLModelMiddleware + options: + show_root_heading: true + merge_init_into_class: false + group_by_category: false + members: + - __init__ + - dispatch diff --git a/docs/api_doc/base_repository.md b/docs/api_doc/base_repository.md new file mode 100644 index 0000000..a0e393e --- /dev/null +++ b/docs/api_doc/base_repository.md @@ -0,0 +1,18 @@ +# BaseRepository +The `BaseRepository` class is a generic repository that provides default methods for Create, Read, Update, and Delete (CRUD) operations. It works with any SQLModel model and an asynchronous database session. + +::: fastapi_async_sql.repositories.BaseRepository + options: + show_root_heading: true + merge_init_into_class: false + group_by_category: false + members: + - __init__ + - get + - get_by_ids + - get_count + - get_multi + - get_multi_paginated + - create + - update + - remove diff --git a/docs/api_doc/base_sql_model.md b/docs/api_doc/base_sql_model.md new file mode 100644 index 0000000..f76c8fe --- /dev/null +++ b/docs/api_doc/base_sql_model.md @@ -0,0 +1,10 @@ +# BaseSQLModel +The `BaseSQLModel` class is an opinionated base class for SQLModel models. + +::: fastapi_async_sql.models.BaseSQLModel + options: + show_root_heading: true + merge_init_into_class: false + group_by_category: false + members: + - __tablename__ \ No newline at end of file diff --git a/docs/api_doc/base_timestamp_model.md b/docs/api_doc/base_timestamp_model.md new file mode 100644 index 0000000..05c5e06 --- /dev/null +++ b/docs/api_doc/base_timestamp_model.md @@ -0,0 +1,11 @@ +# BaseTimestampModel +The BaseTimestampModel class is a mixin that adds created_at and updated_at timestamp fields to your models. + +::: fastapi_async_sql.models.BaseTimestampModel + options: + show_root_heading: true + merge_init_into_class: false + group_by_category: false + members: + - created_at + - updated_at \ No newline at end of file diff --git a/docs/api_doc/base_uuid_model.md b/docs/api_doc/base_uuid_model.md new file mode 100644 index 0000000..49504b3 --- /dev/null +++ b/docs/api_doc/base_uuid_model.md @@ -0,0 +1,10 @@ +# BaseUUIDModel +The `BaseUUIDModel` class is a mixin that adds a UUID-based primary key field to your models. This field is automatically populated with a UUID4 value when a record is created. + +::: fastapi_async_sql.models.BaseUUIDModel + options: + show_root_heading: true + merge_init_into_class: false + group_by_category: false + members: + - id diff --git a/docs/concepts/filtering.md b/docs/concepts/filtering.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/concepts/middlewares.md b/docs/concepts/middlewares.md new file mode 100644 index 0000000..36430e2 --- /dev/null +++ b/docs/concepts/middlewares.md @@ -0,0 +1,32 @@ +# Middlewares + +This module provides middleware for integrating SQLModel with FastAPI using asynchronous SQLAlchemy sessions. The main class in this module is `AsyncSQLModelMiddleware`. + +## `AsyncSQLModelMiddleware` +??? api "API Documentation" + + [fastapi_async_sql.middlewares.AsyncSQLModelMiddleware](../api_doc/async_sql_model_middleware.md) +The `AsyncSQLModelMiddleware` class is a FastAPI middleware that provides an asynchronous SQLModel session (`AsyncSession`) for each request. +It ensures that each request has a database session available via `request.state.db`. + +### Usage Example +Here’s how you can integrate `AsyncSQLModelMiddleware` into your FastAPI application: + + +```python hl_lines="15 25-26 34" +{!./docs_src/middlewares/middleware.py!} +``` + +### How It Works +1. Initialization: The middleware is initialized with either a db_url or a custom_engine. It creates an async engine and an async session maker bound to the engine. +2. Request Handling: During request handling, the middleware creates an async session and attaches it to the request.state.db. This session is used throughout the request and is automatically disposed of after the request is processed. +3. Dependency Injection: You can access the session in your route handlers by extracting it from request.state.db using FastAPI's dependency injection system. + +### Example Configuration +To use `AsyncSQLModelMiddleware`, you can either pass a database URL or a custom SQLAlchemy async engine. Here’s an example configuration: + +```python hl_lines="11 15" +{!./docs_src/middlewares/main.py!} +``` + +With this setup, every request will have access to an async SQLAlchemy session, making it easy to interact with your database using SQLModel in an async manner. diff --git a/docs/concepts/models.md b/docs/concepts/models.md new file mode 100644 index 0000000..79fb303 --- /dev/null +++ b/docs/concepts/models.md @@ -0,0 +1,120 @@ +# Models + +The `models` module in the `fastapi-async-sql` package provides base classes that can be used to create SQLAlchemy models with asynchronous support in FastAPI applications. + +These base classes include functionality for automatic table naming, timestamp fields, and UUID-based primary keys. + +## Base Classes + +### `BaseSQLModel` +??? api "API Documentation" + + [fastapi_async_sql.models.BaseSQLModel](../api_doc/base_sql_model.md) + +The `BaseSQLModel` class is a base model that extends SQLModel and adds asynchronous capabilities via SQLAlchemy's `AsyncAttrs`. + +It automatically generates the table name based on the class name and provides a configuration that supports camelCase aliasing, assignment validation, and strict field handling. + +#### Example Usage + +```python +{!./docs_src/models/model.py!} +``` + +#### Features +- Automatic Table Naming: The table name is automatically generated from the class name and converted to snake_case plural form. +- Pydantic Configuration: Configured to use camelCase for JSON serialization, validate field assignments, populate fields by name, and forbid extra fields. + +#### How It Works +- Table Naming: The `__tablename__` attribute is generated using the `to_snake_plural` function, which converts the class name from PascalCase to snake_case and pluralizes it. For example, a class named Item would have a table name items. +- Pydantic Config: The `model_config` attribute defines how the model behaves with Pydantic, ensuring proper aliasing and validation. + +### `BaseTimestampModel` +??? api "API Documentation" + + [fastapi_async_sql.models.BaseTimestampModel](../api_doc/base_timestamp_model.md) + +The `BaseTimestampModel` class is a mixin that adds `created_at` and `updated_at` timestamp fields to your models. These fields are automatically populated with the current UTC time when a record is created or updated. +#### Example Usage +```python +{!./docs_src/models/timestamp_model.py!} +``` +#### Features +- Automatically populates the `created_at` field with the current UTC time when a record is created. +- Automatically updates the `updated_at` field with the current UTC time when a record is updated. + +#### How It Works +- `created_at`: Uses `default_factory` to set the current UTC time when the record is created. +- `updated_at`: Uses onupdate to set the current UTC time whenever the record is updated. If the record is not updated, the field remains None. + +### `BaseUUIDModel` +??? api "API Documentation" + + [fastapi_async_sql.models.BaseUUIDModel](../api_doc/base_uuid_model.md) + +The `BaseUUIDModel` class is a mixin that adds a UUID-based primary key field to your models. This field is automatically populated with a UUID4 value when a record is created. + +#### Example Usage + +```python +{!./docs_src/models/uuid_model.py!} +``` + +#### Features +- Automatically generates a UUID4 value for the `id` field when a record is created. +- Ensures that each record has a unique UUID as the primary key. + +#### How It Works +- `UUID Generation`: The `id` field is set to use the `uuid4` function as the default value, ensuring a unique UUID is generated for each record. + + +## Combining Base Classes +You can combine the base classes to create models with multiple features. For example, you can create a model with both timestamp fields and a UUID-based primary key. + +### Example Usage +```python +{!./docs_src/models/combined_model.py!} +``` + +### Features +This model would have the following features: + +- Automatic table naming based on the class name. +- CamelCase aliasing and field validation. +- `created_at` and `updated_at` timestamp fields. +- `id` primary key field with UUID4 values. +- Automatic population of timestamp and UUID fields on record creation. +- Proper JSON serialization and validation behavior. +- Strict field handling and aliasing. + +## Using Models with `AsyncAttrs` to Prevent Implicit I/O + +### Understanding Implicit I/O + +In asynchronous applications, it's crucial to avoid implicit I/O operations, especially when working with SQLAlchemy's ORM and lazy-loading relationships. Implicit I/O can occur when you access relationship attributes or deferred columns that haven't been loaded yet. Under traditional asyncio, accessing these attributes directly can lead to errors because the I/O operation needed to fetch the data is not allowed to occur implicitly. + +### What is `AsyncAttrs`? + +`AsyncAttrs` is a mixin provided by SQLAlchemy that helps manage these situations by enabling attributes to be accessed in an awaitable manner. When you use `AsyncAttrs`, any attribute that might trigger a lazy load or deferred column access can be accessed using the `awaitable_attrs` attribute. This ensures that any required I/O operation is explicitly awaited, thus avoiding implicit I/O errors. + +### How to Use `AsyncAttrs` in `fastapi-async-sql` + +The `BaseSQLModel` class in `fastapi-async-sql` includes the `AsyncAttrs` mixin, which makes it easy to work with asynchronous I/O when using SQLModel. + +Here’s how you can use it: + +1. **Define your Models**: Your models should inherit from `BaseSQLModel`, which already includes the `AsyncAttrs` mixin. + +2. **Access Relationships**: When accessing a relationship attribute that is lazy-loaded, use the `awaitable_attrs` accessor to explicitly await the attribute. + +#### Example Usage + +```python hl_lines="31" +{!./docs_src/models/async_attrs_model.py!} +``` + +## Utility Functions +The models module relies on utility functions to handle string conversions and pluralization. These functions are used to generate table names and convert between different naming conventions. + +- `to_camel`: Converts a string from snake_case to camelCase. +- `to_snake_plural`: Converts a string from PascalCase to snake_case plural form. diff --git a/docs/concepts/pagination.md b/docs/concepts/pagination.md new file mode 100644 index 0000000..e69de29 diff --git a/docs/concepts/repositories.md b/docs/concepts/repositories.md new file mode 100644 index 0000000..1dbf20c --- /dev/null +++ b/docs/concepts/repositories.md @@ -0,0 +1,118 @@ +# Repositories + +The `repositories` module in the `fastapi-async-sql` package provides a generic repository pattern implementation to simplify CRUD operations for SQLModel models in FastAPI applications. This module abstracts common database operations, allowing developers to focus on business logic rather than repetitive CRUD operations. + +## BaseRepository +??? api "API Documentation" + + [fastapi_async_sql.repositories.BaseRepository](../api_doc/base_repository.md) + +The `BaseRepository` class is a generic repository that provides default methods for Create, Read, Update, and Delete (CRUD) operations. It works with any SQLModel model and an asynchronous database session. + + +### Example Usage + +Here’s an example of how to use the `BaseRepository` in a FastAPI application: +#### Create SQLModel models +The first step is to define your SQLModel models. These models represent the tables in your database and define the structure of your data. +```python +# Code above omitted πŸ‘† + +{!./docs_src/repositories/repository.py[ln:7]!} + + +{!./docs_src/repositories/repository.py[ln:15-19]!} + +# Code below omitted πŸ‘‡ +``` +/// details | πŸ‘€ Full file preview +```python + +{!./docs_src/repositories/repository.py!} +``` +/// + +#### Create schemas (data models) +Next, create Pydantic schemas to validate and serialize your data. These schemas define the structure of the data that will be sent to and received from your FastAPI endpoints. +```python +# Code above omitted πŸ‘† + +{!./docs_src/repositories/repository.py[ln:22-30]!} + +# Code below omitted πŸ‘‡ +``` +/// details | πŸ‘€ Full file preview +```python + +{!./docs_src/repositories/repository.py!} +``` +/// + +#### Create a repository +Finally, create a repository for your model by extending the `BaseRepository` class. This repository will handle all the CRUD operations for your model. +```python +# Code above omitted πŸ‘† + +{!./docs_src/repositories/repository.py[ln:33-35]!} + +# Code below omitted πŸ‘‡ +``` +/// details | πŸ‘€ Full file preview +```python + +{!./docs_src/repositories/repository.py!} +``` +/// + +#### Create a dependency +To use the repository in your FastAPI application, you can create a dependency that initializes the repository with the database session. +```python +# Code above omitted πŸ‘† + +{!./docs_src/repositories/repository.py[ln:37-41]!} + +# Code below omitted πŸ‘‡ +``` + +/// details | πŸ‘€ Full file preview +```python + +{!./docs_src/repositories/repository.py!} +``` +/// + +#### Use the repository +You can now use the repository in your FastAPI application to interact with your database. The repository provides methods for creating, reading, updating, and deleting objects in the database. +```python +# Code above omitted πŸ‘† + +{!./docs_src/repositories/repository.py[ln:44-]!} + +``` +/// details | πŸ‘€ Full file preview +```python + +{!./docs_src/repositories/repository.py!} +``` +/// + +## Exception Handling + +The `repositories` module makes use of custom exceptions to handle specific error scenarios: + +- `CreateObjectError`: Raised when there is an error creating an object in the database. +- `ObjectNotFoundError`: Raised when an object is not found in the database by its primary key. + +These exceptions help provide more meaningful error messages and make it easier to debug issues in your application. + +## Filtering and Pagination + +The `get_multi` and `get_multi_paginated` methods support filtering and pagination through the `Filter` and `Params` classes, respectively. These features make it easier to manage large datasets and retrieve only the data you need. + +## Additional resources +- [SQLModel documentation](https://sqlmodel.tiangolo.com/tutorial/fastapi/multiple-models/) +- [Pydantic documentation](https://docs.pydantic.dev/latest/) +- [FastAPI documentation](https://fastapi.tiangolo.com/) +- [SQLAlchemy documentation](https://docs.sqlalchemy.org/) +- [FastAPI Pagination documentation](https://uriyyo-fastapi-pagination.netlify.app/) +- [FastAPI Filter documentation](https://fastapi-filter.netlify.app/) \ No newline at end of file diff --git a/docs/extra/style.css b/docs/extra/style.css new file mode 100644 index 0000000..17b52b1 --- /dev/null +++ b/docs/extra/style.css @@ -0,0 +1,15 @@ +/* API documentation link admonition */ +:root { + --md-admonition-icon--api: url('data:image/svg+xml;charset=utf-8,') +} +.md-typeset .admonition.api, .md-typeset details.api { + border-color: #448aff; +} +.md-typeset .api > .admonition-title, .md-typeset .api > summary { + background-color: #448aff1a; +} +.md-typeset .api > .admonition-title::before, .md-typeset .api > summary::before { + background-color: #448aff; + -webkit-mask-image: var(--md-admonition-icon--api); + mask-image: var(--md-admonition-icon--api); +} diff --git a/docs/index.md b/docs/index.md index 24fbe28..33b583c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -12,7 +12,7 @@ uv pip install fastapi-async-sql ## Features -- AsyncSQLAlchemyMiddleware: A middleware to handle database connections with AsyncSQLAlchemy +- AsyncSQLModelMiddleware: A middleware to handle database connections with AsyncSQLAlchemy - [SQLModel](https://sqlmodel.tiangolo.com/): A library to handle database models with Pydantic and SQLAlchemy - Base models for `SQLModel`: - `BaseSQLModel`: A opinionated base model for SQLAlchemy models diff --git a/docs_src/middlewares/__init__.py b/docs_src/middlewares/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docs_src/middlewares/main.py b/docs_src/middlewares/main.py new file mode 100644 index 0000000..5f85314 --- /dev/null +++ b/docs_src/middlewares/main.py @@ -0,0 +1,15 @@ +from fastapi import FastAPI +from sqlalchemy.ext.asyncio import create_async_engine + +from fastapi_async_sql.middlewares import AsyncSQLModelMiddleware + +app = FastAPI() + +DATABASE_URL = "sqlite+aiosqlite:///./test.db" + +# Option 1: Using database URL +app.add_middleware(AsyncSQLModelMiddleware, db_url=DATABASE_URL) + +# Option 2: Using custom engine +engine = create_async_engine(DATABASE_URL) +app.add_middleware(AsyncSQLModelMiddleware, custom_engine=engine) diff --git a/docs_src/middlewares/middleware.py b/docs_src/middlewares/middleware.py new file mode 100644 index 0000000..adb9a23 --- /dev/null +++ b/docs_src/middlewares/middleware.py @@ -0,0 +1,36 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Request +from sqlmodel import Field, SQLModel, select +from sqlmodel.ext.asyncio.session import AsyncSession + +from fastapi_async_sql.middlewares import AsyncSQLModelMiddleware + +app = FastAPI() + +# Define the database URL +DATABASE_URL = "sqlite+aiosqlite:///./test.db" + +# Add the middleware to the FastAPI app +app.add_middleware(AsyncSQLModelMiddleware, db_url=DATABASE_URL) + + +# Example model +class Item(SQLModel, table=True): + id: int | None = Field(default=None, primary_key=True) + name: str + + +# Dependency that extracts the session from the request state +async def get_session(request: Request) -> AsyncSession: + return request.state.db + + +# Annotate the get_session dependency, so you can use it in routes +AsyncSessionDependency = Annotated[AsyncSession, Depends(get_session)] + + +@app.get("/items") +async def read_items(session: AsyncSessionDependency): + items = await session.exec(select(Item)) + return items diff --git a/docs_src/models/__init__.py b/docs_src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docs_src/models/async_attrs_model.py b/docs_src/models/async_attrs_model.py new file mode 100644 index 0000000..34325d6 --- /dev/null +++ b/docs_src/models/async_attrs_model.py @@ -0,0 +1,33 @@ +from fastapi import FastAPI, Request +from sqlmodel import Field, Relationship + +from fastapi_async_sql.middlewares import AsyncSQLModelMiddleware +from fastapi_async_sql.models import BaseSQLModel + +app = FastAPI() +app.add_middleware(AsyncSQLModelMiddleware, db_url="sqlite+aiosqlite:///./test.db") + + +class User(BaseSQLModel, table=True): + id: int = Field(default=None, primary_key=True) + name: str + posts: list["Post"] = Relationship(back_populates="user") + + +class Post(BaseSQLModel, table=True): + id: int = Field(default=None, primary_key=True) + title: str + content: str + user_id: int = Field(foreign_key="user.id") + user: "User" = Relationship(back_populates="posts") + + +# Example of accessing the relationship +@app.get("/users/{user_id}/posts") +async def get_user_posts(request: Request, user_id: int): + user = await request.state.db.get(User, user_id) + + # Accessing the lazy-loaded relationship with awaitable_attrs + posts = await user.awaitable_attrs.posts + + return posts diff --git a/docs_src/models/combined_model.py b/docs_src/models/combined_model.py new file mode 100644 index 0000000..46559bd --- /dev/null +++ b/docs_src/models/combined_model.py @@ -0,0 +1,5 @@ +from fastapi_async_sql.models import BaseSQLModel, BaseTimestampModel, BaseUUIDModel + + +class Item(BaseSQLModel, BaseTimestampModel, BaseUUIDModel, table=True): + name: str diff --git a/docs_src/models/model.py b/docs_src/models/model.py new file mode 100644 index 0000000..63c2b3a --- /dev/null +++ b/docs_src/models/model.py @@ -0,0 +1,8 @@ +from sqlmodel import Field + +from fastapi_async_sql.models import BaseSQLModel + + +class Item(BaseSQLModel, table=True): + id: int = Field(primary_key=True) + name: str diff --git a/docs_src/models/timestamp_model.py b/docs_src/models/timestamp_model.py new file mode 100644 index 0000000..a7c6506 --- /dev/null +++ b/docs_src/models/timestamp_model.py @@ -0,0 +1,5 @@ +from fastapi_async_sql.models import BaseSQLModel, BaseTimestampModel + + +class Item(BaseSQLModel, BaseTimestampModel, table=True): + name: str diff --git a/docs_src/models/uuid_model.py b/docs_src/models/uuid_model.py new file mode 100644 index 0000000..1a0284b --- /dev/null +++ b/docs_src/models/uuid_model.py @@ -0,0 +1,5 @@ +from fastapi_async_sql.models import BaseSQLModel, BaseUUIDModel + + +class Item(BaseSQLModel, BaseUUIDModel, table=True): + name: str diff --git a/docs_src/repositories/__init__.py b/docs_src/repositories/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/docs_src/repositories/repository.py b/docs_src/repositories/repository.py new file mode 100644 index 0000000..d65ce99 --- /dev/null +++ b/docs_src/repositories/repository.py @@ -0,0 +1,76 @@ +from typing import Annotated + +from fastapi import Depends, FastAPI, Request +from pydantic import UUID4 + +from fastapi_async_sql.middlewares import AsyncSQLModelMiddleware +from fastapi_async_sql.models import BaseSQLModel +from fastapi_async_sql.repositories import BaseRepository +from fastapi_async_sql.utils.partial import optional + +app = FastAPI() +app.add_middleware(AsyncSQLModelMiddleware, db_url="sqlite+aiosqlite:///./test.db") + + +class ItemBase(BaseSQLModel): + name: str + + +class Item(ItemBase, table=True): ... + + +class ItemCreateSchema(ItemBase): ... + + +@optional() +class ItemUpdateSchema(ItemBase): ... + + +class ItemReadSchema(ItemBase): + id: UUID4 + + +class ItemRepository(BaseRepository[Item, ItemCreateSchema, ItemUpdateSchema]): + pass + + +def get_item_repository(request: Request) -> ItemRepository: + return ItemRepository(Item, request.state.db) + + +ItemRepositoryDependency = Annotated[ItemRepository, Depends(get_item_repository)] + + +@app.get("/items/", response_model=list[ItemReadSchema]) +async def read_items( + item_repository: ItemRepositoryDependency, +): + return await item_repository.get_multi() + + +@app.post("/items/", response_model=ItemReadSchema) +async def create_item( + item_repository: ItemRepositoryDependency, + item_in: ItemCreateSchema, +): + return await item_repository.create(obj_in=item_in) + + +@app.patch("/items/{item_id}", response_model=ItemReadSchema) +async def update_item( + item_repository: ItemRepositoryDependency, + item_id: UUID4, + item_new: ItemUpdateSchema, +): + item_current = await item_repository.get(id=item_id) + return await item_repository.update(obj_current=item_current, obj_new=item_new) + + +@app.delete("/items/{item_id}") +async def delete_item( + item_repository: ItemRepositoryDependency, + item_id: UUID4, +): + item = await item_repository.get(id=item_id) + await item_repository.remove(id=item.id) + return None diff --git a/mkdocs.yml b/mkdocs.yml index 52ba6e9..5053552 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -5,25 +5,74 @@ repo_name: pygrok/fastapi-async-sql theme: name: material palette: - primary: teal + primary: 'pink' icon: repo: fontawesome/brands/github + admonition: + api: fontawesome/solid/note-sticky features: - content.code.copy + - navigation.tabs plugins: + - search + - mkdocstrings - termynal: title: bash buttons: macos prompt_literal_start: - "$" - - git-revision-date-localized: - enable_creation_date: true - type: date markdown_extensions: - - pymdownx.highlight: - anchor_linenums: true - line_spans: __span - pygments_lang_class: true - - pymdownx.inlinehilite - - pymdownx.snippets - - pymdownx.superfences \ No newline at end of file + toc: + permalink: True + admonition: + # Python Markdown Extensions + pymdownx.betterem: + smart_enable: all + pymdownx.caret: + pymdownx.highlight: + line_spans: __span + pymdownx.inlinehilite: + pymdownx.keys: + pymdownx.mark: + pymdownx.details: + pymdownx.superfences: + custom_fences: + - name: mermaid + class: mermaid + format: "!!python/name:pymdownx.superfences.fence_code_format" + pymdownx.tilde: + + # pymdownx blocks + pymdownx.blocks.admonition: + types: + - note + - attention + - caution + - danger + - error + - tip + - hint + - warning + # Custom types + - info + pymdownx.blocks.details: + pymdownx.blocks.tab: + alternate_style: True + mdx_include: + base_path: . +extra_css: + - 'extra/style.css' +nav: + - FastAPI Async SQL: index.md + - Concepts: + - Middlewares: concepts/middlewares.md + - Models: concepts/models.md + - Repositories: concepts/repositories.md + - Pagination: concepts/pagination.md + - Filtering: concepts/filtering.md + - API Documentation: + - AsyncSQLModelMiddleware: api_doc/async_sql_model_middleware.md + - BaseSQLModel: api_doc/base_sql_model.md + - BaseTimestampModel: api_doc/base_timestamp_model.md + - BaseUUIDModel: api_doc/base_uuid_model.md + - BaseRepository: api_doc/base_repository.md diff --git a/pyproject.toml b/pyproject.toml index 6be53c9..937fb4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ docs = [ "mkdocs-material>=9.5.33", "mkdocs-git-revision-date-localized-plugin>=1.2.7", "termynal>=0.12.1", + "mdx-include>=1.4.2", + "mkdocstrings[python]>=0.25.2", ] [build-system] @@ -66,6 +68,9 @@ lint.ignore = [ "D203", # 1 blank line required before class docstring "D212", # Multi-line docstring summary should start at the first line ] +[tool.ruff.lint.per-file-ignores] +"docs/*" = ["D"] +"docs_src/*" = ["D"] [tool.ruff.lint.isort] known-local-folder = ["tests", "cmd", "src", "pkg"] diff --git a/src/fastapi_async_sql/models.py b/src/fastapi_async_sql/models.py index 51bf33e..2434870 100644 --- a/src/fastapi_async_sql/models.py +++ b/src/fastapi_async_sql/models.py @@ -1,4 +1,4 @@ -"""Base model.""" +"""Base models.""" from datetime import datetime, timezone from uuid import UUID, uuid4 @@ -12,6 +12,17 @@ class BaseSQLModel(AsyncAttrs, SQLModel): + """Base SQL model with automatic __tablename__ generation. + + /// info | Usage Documentation + [Models](../concepts/models.md#basesqlmodel) + /// + + Attributes: + __tablename__ (str): The table name for the model. + model_config (ConfigDict): The configuration for the model. + """ + @declared_attr # type: ignore def __tablename__(cls) -> str: """Generate __tablename__ automatically.""" @@ -26,6 +37,13 @@ def __tablename__(cls) -> str: class BaseTimestampModel: + """Base model with created_at and updated_at fields. + + /// info | Usage Documentation + [Models](../concepts/models.md#basetimestampmodel) + /// + """ + created_at: AwareDatetime = Field( default_factory=lambda: datetime.now(tz=timezone.utc), sa_type=TIMESTAMP(timezone=True), @@ -39,6 +57,13 @@ class BaseTimestampModel: class BaseUUIDModel: + """Base model with UUID primary key. + + /// info | Usage Documentation + [Models](../concepts/models.md#baseuuidmodel) + /// + """ + id: UUID = Field( default_factory=uuid4, primary_key=True, diff --git a/src/fastapi_async_sql/repositories.py b/src/fastapi_async_sql/repositories.py index 26eec7a..cc13351 100644 --- a/src/fastapi_async_sql/repositories.py +++ b/src/fastapi_async_sql/repositories.py @@ -22,7 +22,18 @@ class BaseRepository(Generic[ModelType, CreateSchemaType, UpdateSchemaType]): - def __init__(self, model: type[ModelType], db: AsyncSession | None = None): + """Base Repository with default methods to Create, Read, Update, Delete (CRUD). + + /// info | Usage Documentation + [Repositories](../concepts/repositories.md#baserepository) + /// + + Attributes: + model (type[ModelType]): The model to be used in the Repository. + db (AsyncSession | None): The database session to be used. Defaults to None. + """ + + def __init__(self, model: type[ModelType], db: AsyncSession | None = None) -> None: """Repository with default methods to Create, Read, Update, Delete (CRUD). Args: diff --git a/tests/repositories/hero_repository.py b/tests/repositories/hero_repository.py index 09ee9a9..bdebdca 100644 --- a/tests/repositories/hero_repository.py +++ b/tests/repositories/hero_repository.py @@ -3,8 +3,8 @@ from fastapi_async_sql.repositories import BaseRepository from ..models.hero_model import Hero -from ..schemas.hero_schema import IHeroCreate, IHeroUpdate +from ..schemas.hero_schemas import HeroCreateSchema, HeroUpdateSchema -class HeroRepository(BaseRepository[Hero, IHeroCreate, IHeroUpdate]): +class HeroRepository(BaseRepository[Hero, HeroCreateSchema, HeroUpdateSchema]): pass diff --git a/tests/repositories/item_repository.py b/tests/repositories/item_repository.py index 1d10244..e21340e 100644 --- a/tests/repositories/item_repository.py +++ b/tests/repositories/item_repository.py @@ -3,8 +3,8 @@ from fastapi_async_sql.repositories import BaseRepository from ..models.item_model import Item -from ..schemas.item_schema import IItemCreate, IItemUpdate +from ..schemas.item_schemas import ItemCreateSchema, ItemUpdateSchema -class ItemRepository(BaseRepository[Item, IItemCreate, IItemUpdate]): +class ItemRepository(BaseRepository[Item, ItemCreateSchema, ItemUpdateSchema]): pass diff --git a/tests/schemas/hero_schema.py b/tests/schemas/hero_schemas.py similarity index 58% rename from tests/schemas/hero_schema.py rename to tests/schemas/hero_schemas.py index e2c87b9..848bb8f 100644 --- a/tests/schemas/hero_schema.py +++ b/tests/schemas/hero_schemas.py @@ -9,30 +9,30 @@ from ..models.team_model import TeamBase -class IHeroCreate(HeroBase): +class HeroCreateSchema(HeroBase): pass @optional() -class IHeroUpdate(HeroBase): +class HeroUpdateSchema(HeroBase): team_id: UUID4 | None = None item_id: UUID4 | None = None -class IHeroRead(HeroBase): +class HeroReadSchema(HeroBase): id: UUID4 -class HeroTeamRead(TeamBase): +class _HeroTeamReadSchema(TeamBase): id: UUID4 -class ItemTeamRead(ItemBase): +class _ItemTeamReadSchema(ItemBase): id: UUID4 created_by_id: UUID4 -class IHeroReadWithTeam(IHeroRead): +class HeroReadWithTeamSchema(HeroReadSchema): id: UUID4 - team: HeroTeamRead = None - item: ItemTeamRead = None + team: _HeroTeamReadSchema = None + item: _ItemTeamReadSchema = None diff --git a/tests/schemas/item_schema.py b/tests/schemas/item_schemas.py similarity index 66% rename from tests/schemas/item_schema.py rename to tests/schemas/item_schemas.py index 3371d41..bd46a81 100644 --- a/tests/schemas/item_schema.py +++ b/tests/schemas/item_schemas.py @@ -7,13 +7,13 @@ from ..models.item_model import ItemBase -class IItemCreate(ItemBase): ... +class ItemCreateSchema(ItemBase): ... @optional() -class IItemUpdate(ItemBase): ... +class ItemUpdateSchema(ItemBase): ... -class IItemRead(ItemBase): +class ItemReadSchema(ItemBase): id: UUID4 created_by_id: UUID4 diff --git a/tests/test_api.py b/tests/test_api.py index 73d33eb..14da301 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -17,7 +17,7 @@ from tests.dependencies import AnnotatedRepositoryHero from tests.models.hero_model import Hero from tests.models.team_model import Team -from tests.schemas.hero_schema import IHeroRead, IHeroReadWithTeam +from tests.schemas.hero_schemas import HeroReadSchema, HeroReadWithTeamSchema @pytest.mark.parametrize( @@ -44,7 +44,7 @@ async def test_get_paginated_response( @app.get("/heroes") async def get_heroes( repository: AnnotatedRepositoryHero, params: Params = Depends() - ) -> Page[IHeroRead]: + ) -> Page[HeroReadSchema]: return await repository.get_multi_paginated(page_params=params) response = await client.get("/heroes", params=query_params) @@ -303,7 +303,7 @@ async def get_heroes( repository: AnnotatedRepositoryHero, params: Params = Depends(), filter_by: HeroFilter = FilterDepends(HeroFilter), - ) -> list[IHeroRead]: + ) -> list[HeroReadSchema]: query = select(Hero).outerjoin(Team) return await repository.get_multi( query=query, page_params=params, filter_by=filter_by @@ -314,7 +314,7 @@ async def get_heroes_by_alias( repository: AnnotatedRepositoryHero, params: Params = Depends(), filter_by: HeroFilterByAlias = FilterDepends(HeroFilterByAlias, by_alias=True), - ) -> list[IHeroRead]: + ) -> list[HeroReadSchema]: return await repository.get_multi(page_params=params, filter_by=filter_by) response = await client.get(f"{endpoint}?{urlencode(filter_clause)}") @@ -332,7 +332,7 @@ async def test_get_heroes_with_relationships( @app.get("/heroes") async def get_heroes( repository: AnnotatedRepositoryHero, - ) -> list[IHeroReadWithTeam]: + ) -> list[HeroReadWithTeamSchema]: query = ( select(Hero) .options(selectinload(Hero.team)) @@ -375,7 +375,7 @@ async def test_get_hero_with_relationships_with_lazy_loading( @app.get("/heroes/{hero_id}") async def get_heroes( hero_id: UUID4, repository: AnnotatedRepositoryHero - ) -> IHeroReadWithTeam: + ) -> HeroReadWithTeamSchema: response = await repository.get(id=hero_id) response.item = await response.awaitable_attrs.item response.team = await response.awaitable_attrs.team diff --git a/tests/test_middlewares.py b/tests/test_middlewares.py index ab2e50d..7a7754d 100644 --- a/tests/test_middlewares.py +++ b/tests/test_middlewares.py @@ -17,7 +17,7 @@ from tests.models.hero_model import Hero from tests.models.item_model import Item from tests.models.team_model import Team -from tests.schemas.hero_schema import IHeroRead +from tests.schemas.hero_schemas import HeroReadSchema @pytest.fixture(scope="function") @@ -143,7 +143,7 @@ async def test_async_sqlalchemy_middleware_db_session_commit( """Test that the middleware correctly commits the session.""" @app_with_db_middleware.post( - "/heroes", response_model=IHeroRead, status_code=status.HTTP_201_CREATED + "/heroes", response_model=HeroReadSchema, status_code=status.HTTP_201_CREATED ) async def create_hero(request: Request): hero = Hero( diff --git a/tests/test_repositories.py b/tests/test_repositories.py index 1c9b223..b1c9ee3 100644 --- a/tests/test_repositories.py +++ b/tests/test_repositories.py @@ -13,8 +13,8 @@ from tests.models.item_model import Item from tests.models.team_model import Team from tests.repositories import HeroRepository, ItemRepository -from tests.schemas.hero_schema import IHeroCreate, IHeroUpdate -from tests.schemas.item_schema import IItemCreate +from tests.schemas.hero_schemas import HeroCreateSchema, HeroUpdateSchema +from tests.schemas.item_schemas import ItemCreateSchema async def test_create_hero( @@ -24,7 +24,7 @@ async def test_create_hero( hero_repository: HeroRepository, ): """Test create hero.""" - hero_data = IHeroCreate( + hero_data = HeroCreateSchema( name="Test Hero", age=30, secret_identity="Test Identity", # nosec: B106 @@ -44,7 +44,7 @@ async def test_create_item_with_extra_data( item_repository: ItemRepository, ): """Test create item with extra data.""" - item_data = IItemCreate(name="Test Item") + item_data = ItemCreateSchema(name="Test Item") created_by_id = uuid4() created_item = await item_repository.create( obj_in=item_data, created_by_id=created_by_id @@ -61,7 +61,7 @@ async def test_create_hero_duplicate( hero_repository: HeroRepository, ): """Test create hero with duplicate data.""" - hero_data = IHeroCreate( + hero_data = HeroCreateSchema( name=heroes[0].name, age=30, secret_identity="Test Identity", # nosec: B106 @@ -124,7 +124,7 @@ async def test_update_hero( """Test update hero.""" hero = heroes[0] assert hero.updated_at is None - update_data = IHeroUpdate(name="Updated Hero") + update_data = HeroUpdateSchema(name="Updated Hero") updated_hero = await hero_repository.update(obj_current=hero, obj_new=update_data) assert updated_hero.name == "Updated Hero" assert isinstance(updated_hero.updated_at, datetime) diff --git a/uv.lock b/uv.lock index 1201470..ef25005 100644 --- a/uv.lock +++ b/uv.lock @@ -170,6 +170,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9f/b0/e0dca6da9170aefc07515cce067b97178cefafb512d00a87a1c717d2efd5/coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657", size = 211453 }, ] +[[package]] +name = "cyclic" +version = "1.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bf/9f/becc4fea44301f232e4eba17752001bd708e3c042fef37a72b9af7ddf4b5/cyclic-1.0.0.tar.gz", hash = "sha256:ecddd56cb831ee3e6b79f61ecb0ad71caee606c507136867782911aa01c3e5eb", size = 2167 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/c0/9f59d2ebd9d585e1681c51767eb138bcd9d0ea770f6fc003cd875c7f5e62/cyclic-1.0.0-py3-none-any.whl", hash = "sha256:32d8181d7698f426bce6f14f4c3921ef95b6a84af9f96192b59beb05bc00c3ed", size = 2547 }, +] + [[package]] name = "distlib" version = "0.3.8" @@ -231,8 +240,10 @@ dependencies = [ [package.optional-dependencies] docs = [ + { name = "mdx-include" }, { name = "mkdocs-git-revision-date-localized-plugin" }, { name = "mkdocs-material" }, + { name = "mkdocstrings", extra = ["python"] }, { name = "termynal" }, ] lint = [ @@ -259,8 +270,10 @@ requires-dist = [ { name = "fastapi-pagination", specifier = ">=0.12.26,<0.13.0" }, { name = "httpx", marker = "extra == 'test'", specifier = ">=0.27.2" }, { name = "inflect", specifier = ">=7.3.1,<8.0.0" }, + { name = "mdx-include", marker = "extra == 'docs'", specifier = ">=1.4.2" }, { name = "mkdocs-git-revision-date-localized-plugin", marker = "extra == 'docs'", specifier = ">=1.2.7" }, { name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.5.33" }, + { name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = ">=0.25.2" }, { name = "pre-commit", marker = "extra == 'lint'", specifier = ">=3.8.0" }, { name = "pytest", marker = "extra == 'test'", specifier = ">=8.3.2" }, { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.24.0" }, @@ -365,6 +378,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/53/80/3d94d5999b4179d91bcc93745d1b0815b073d61be79dd546b840d17adb18/greenlet-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:bba5387a6975598857d86de9eac14210a49d554a77eb8261cc68b7d082f78ce2", size = 293635 }, ] +[[package]] +name = "griffe" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/92/79/cbe9be5ac67bfd87c48b5b2d2fa170bf8c33b38d8661d9d1849f038ab1f9/griffe-1.2.0.tar.gz", hash = "sha256:1c9f6ef7455930f3f9b0c4145a961c90385d1e2cbc496f7796fbff560ec60d31", size = 381349 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/0b/5caa5617b63535fe1e0abc23af92bd1e6df4bd3d5b72bfe2c675d4770235/griffe-1.2.0-py3-none-any.whl", hash = "sha256:a8b2fcb1ecdc5a412e646b0b4375eb20a5d2eac3a11dd8c10c56967a4097663c", size = 126930 }, +] + [[package]] name = "h11" version = "0.14.0" @@ -495,6 +520,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/3f/14/c3554d512d5f9100a95e737502f4a2323a1959f6d0d01e0d0997b35f7b10/MarkupSafe-2.1.5-cp312-cp312-win_amd64.whl", hash = "sha256:823b65d8706e32ad2df51ed89496147a42a2a6e01c13cfb6ffb8b1e92bc910bb", size = 17127 }, ] +[[package]] +name = "mdx-include" +version = "1.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cyclic" }, + { name = "markdown" }, + { name = "rcslice" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bf/f0/f395a9cf164471d3c7bbe58cbd64d74289575a8b85a962b49a804ab7ed34/mdx_include-1.4.2.tar.gz", hash = "sha256:992f9fbc492b5cf43f7d8cb4b90b52a4e4c5fdd7fd04570290a83eea5c84f297", size = 15051 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/40/6844997dee251103c5a4c4eb0d1d2f2162b7c29ffc4e86de3cd68d269be2/mdx_include-1.4.2-py3-none-any.whl", hash = "sha256:cfbeadd59985f27a9b70cb7ab0a3d209892fe1bb1aa342df055e0b135b3c9f34", size = 11591 }, +] + [[package]] name = "mergedeep" version = "1.3.4" @@ -528,6 +567,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b8/c0/930dcf5a3e96b9c8e7ad15502603fc61d495479699e2d2c381e3d37294d1/mkdocs-1.6.0-py3-none-any.whl", hash = "sha256:1eb5cb7676b7d89323e62b56235010216319217d4af5ddc543a91beb8d125ea7", size = 3862264 }, ] +[[package]] +name = "mkdocs-autorefs" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markdown" }, + { name = "markupsafe" }, + { name = "mkdocs" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/47/24/7d09b72b470d5dd33ed0c6722c7038ece494ab7dc5e72adbfeaf945276f6/mkdocs_autorefs-1.1.0.tar.gz", hash = "sha256:f2fd43b11f66284bd014f9b542a05c8ecbfaad4e0d7b30b68584788217b6c656", size = 36989 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/84/8e4dd669766f864482d2edcc44c7f07f4f91d73414c9a3e33b230a59f2cf/mkdocs_autorefs-1.1.0-py3-none-any.whl", hash = "sha256:492ac42f50214e81565e968f8cb0df9aba9d981542b9e7121b8f8ae9407fe6eb", size = 14417 }, +] + [[package]] name = "mkdocs-get-deps" version = "0.2.0" @@ -588,6 +641,43 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5b/54/662a4743aa81d9582ee9339d4ffa3c8fd40a4965e033d77b9da9774d3960/mkdocs_material_extensions-1.3.1-py3-none-any.whl", hash = "sha256:adff8b62700b25cb77b53358dad940f3ef973dd6db797907c49e3c2ef3ab4e31", size = 8728 }, ] +[[package]] +name = "mkdocstrings" +version = "0.25.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "jinja2" }, + { name = "markdown" }, + { name = "markupsafe" }, + { name = "mkdocs" }, + { name = "mkdocs-autorefs" }, + { name = "platformdirs" }, + { name = "pymdown-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/24/a6/d544fae9749b19e23fb590f6344f9eae3a312323065070b4874236bb0e04/mkdocstrings-0.25.2.tar.gz", hash = "sha256:5cf57ad7f61e8be3111a2458b4e49c2029c9cb35525393b179f9c916ca8042dc", size = 91796 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2b/86/ee2aef075cc9a62a4f087c3c3f4e3e8a8318afe05a92f8f8415f1bf1af64/mkdocstrings-0.25.2-py3-none-any.whl", hash = "sha256:9e2cda5e2e12db8bb98d21e3410f3f27f8faab685a24b03b06ba7daa5b92abfc", size = 29289 }, +] + +[package.optional-dependencies] +python = [ + { name = "mkdocstrings-python" }, +] + +[[package]] +name = "mkdocstrings-python" +version = "1.10.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "griffe" }, + { name = "mkdocstrings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6b/ae/21b26c1fd62c8dd51ecefc8e848c14ca8f0dfdfeb903deeb20e86fb28ad1/mkdocstrings_python-1.10.8.tar.gz", hash = "sha256:5856a59cbebbb8deb133224a540de1ff60bded25e54d8beacc375bb133d39016", size = 161724 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c5/81/cda8afc58c7be82f0a7a9b54939e86f54eaad32a5b232afd6893ce3f00cb/mkdocstrings_python-1.10.8-py3-none-any.whl", hash = "sha256:bb12e76c8b071686617f824029cb1dfe0e9afe89f27fb3ad9a27f95f054dcd89", size = 108333 }, +] + [[package]] name = "more-itertools" version = "10.4.0" @@ -862,6 +952,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/5a/66/bbb1dd374f5c870f59c5bb1db0e18cbe7fa739415a24cbd95b2d1f5ae0c4/pyyaml_env_tag-0.1-py3-none-any.whl", hash = "sha256:af31106dec8a4d68c60207c1886031cbf839b68aa7abccdb19868200532c2069", size = 3911 }, ] +[[package]] +name = "rcslice" +version = "1.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/53/3e/abe47d91d5340b77b003baf96fdf8966c946eb4c5a704a844b5d03e6e578/rcslice-1.1.0.tar.gz", hash = "sha256:a2ce70a60690eb63e52b722e046b334c3aaec5e900b28578f529878782ee5c6e", size = 4414 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/aa/96/7935186fba032312eb8a75e6503440b0e6de76c901421f791408e4debd93/rcslice-1.1.0-py3-none-any.whl", hash = "sha256:1b12fc0c0ca452e8a9fd2b56ac008162f19e250906a4290a7e7a98be3200c2a6", size = 5180 }, +] + [[package]] name = "regex" version = "2024.7.24" From e2cf6a15620c887e9b536ad2aa5510c5b1a7c453 Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Sun, 9 Feb 2025 19:06:12 -0700 Subject: [PATCH 4/8] Fix timezone aware fields Signed-off-by: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> --- .gitignore | 7 +- pyproject.toml | 3 +- src/fastapi_async_sql/models.py | 8 +- src/fastapi_async_sql/typing.py | 28 +++++ tests/test_models.py | 181 ++++++++++++++++++++++++++++++++ uv.lock | 9 +- 6 files changed, 225 insertions(+), 11 deletions(-) create mode 100644 src/fastapi_async_sql/typing.py create mode 100644 tests/test_models.py diff --git a/.gitignore b/.gitignore index 809a958..fcd58d9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,4 @@ -.vscode/* -!.vscode/settings.json -!.vscode/tasks.json -!.vscode/launch.json -!.vscode/extensions.json -!.vscode/*.code-snippets +.vscode/ .history/ *.vsix .idea diff --git a/pyproject.toml b/pyproject.toml index 937fb4f..6a0876d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fastapi-async-sql" -version = "0.1.0-alpha.1" +version = "0.2.0-alpha.2" description = "Common utilities for Async SQL FastAPI applications" readme = "README.md" requires-python = ">=3.12" @@ -10,6 +10,7 @@ dependencies = [ "fastapi-pagination>=0.12.26,<0.13.0", "fastapi-filter[sqlalchemy]>=2.0.0,<3.0.0", "inflect>=7.3.1,<8.0.0", + "sqlalchemy[asyncio]>=2.0.32", ] [project.optional-dependencies] diff --git a/src/fastapi_async_sql/models.py b/src/fastapi_async_sql/models.py index 2434870..23ac378 100644 --- a/src/fastapi_async_sql/models.py +++ b/src/fastapi_async_sql/models.py @@ -6,10 +6,12 @@ from pydantic import AwareDatetime, ConfigDict from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.declarative import declared_attr -from sqlmodel import TIMESTAMP, Field, SQLModel +from sqlmodel import Field, SQLModel from fastapi_async_sql.utils.string import to_camel, to_snake_plural +from .typing import TimeStamp + class BaseSQLModel(AsyncAttrs, SQLModel): """Base SQL model with automatic __tablename__ generation. @@ -46,12 +48,12 @@ class BaseTimestampModel: created_at: AwareDatetime = Field( default_factory=lambda: datetime.now(tz=timezone.utc), - sa_type=TIMESTAMP(timezone=True), + sa_type=TimeStamp(timezone=True), ) updated_at: AwareDatetime | None = Field( default=None, - sa_type=TIMESTAMP(timezone=True), + sa_type=TimeStamp(timezone=True), sa_column_kwargs={"onupdate": lambda: datetime.now(tz=timezone.utc)}, ) diff --git a/src/fastapi_async_sql/typing.py b/src/fastapi_async_sql/typing.py new file mode 100644 index 0000000..0fa66db --- /dev/null +++ b/src/fastapi_async_sql/typing.py @@ -0,0 +1,28 @@ +"""This module defines custom types for the FastAPI Async SQL package.""" + +from datetime import datetime, timezone + +import sqlalchemy as sa + + +class TimeStamp(sa.types.TypeDecorator): + impl = sa.types.DateTime + LOCAL_TIMEZONE = timezone.utc + + def process_bind_param(self, value: datetime | None, dialect): + """Convert datetime to UTC timezone.""" + if value is None: + return None + if value.tzinfo is None: + value = value.astimezone(self.LOCAL_TIMEZONE) + + return value.astimezone(timezone.utc) + + def process_result_value(self, value, dialect): + """Convert datetime to UTC timezone.""" + if value is None: + return None + if value.tzinfo is None: + return value.replace(tzinfo=timezone.utc) + + return value.astimezone(timezone.utc) diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..d660137 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,181 @@ +"""Tests for models.""" + +from datetime import timezone +from uuid import uuid4 + +import pytest +import sqlalchemy as sa +from sqlmodel import Session + +from .models.hero_model import Hero +from .models.item_model import Item +from .models.team_model import Team + + +async def test_create_hero(db: Session, team: Team, item: Item): + """Test creating a hero.""" + hero = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team=team, + item=item, + ) + db.add(hero) + await db.commit() + await db.refresh(hero) + assert hero.id is not None + assert hero.created_at is not None + assert hero.updated_at is None + + +async def test_create_team(db: Session): + """Test creating a team.""" + team = Team(name="Test Team", headquarters="Test HQ") + db.add(team) + await db.commit() + await db.refresh(team) + assert team.id is not None + assert team.created_at is not None + assert team.updated_at is None + + +async def test_create_item(db: Session): + """Test creating an item.""" + item = Item(name="Test Item", created_by_id=uuid4()) + db.add(item) + await db.commit() + await db.refresh(item) + assert item.id is not None + assert item.created_at is not None + assert item.updated_at is None + + +async def test_hero_relationships(db: Session): + """Test hero relationships.""" + team = Team(name="Test Team", headquarters="Test HQ") + item = Item(name="Test Item", created_by_id=uuid4()) + hero = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team=team, + item=item, + ) + db.add(hero) + await db.commit() + await db.refresh(hero) + assert hero.awaitable_attrs.team is not None + assert hero.awaitable_attrs.item is not None + + +async def test_timestamp_update(db: Session, team: Team, item: Item): + """Test timestamp update.""" + hero = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team=team, + item=item, + ) + db.add(hero) + await db.commit() + await db.refresh(hero) + hero.name = "Updated Hero" + await db.commit() + await db.refresh(hero) + assert hero.updated_at is not None + + +async def test_unique_constraint(db: Session, team: Team, item: Item): + """Test unique constraint violation.""" + hero1 = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team=team, + item=item, + ) + hero2 = Hero( + name="Test Hero", + secret_identity="Another Identity", # nosec:B106 + age=25, + team=team, + item=item, + ) + db.add(hero1) + await db.commit() + db.add(hero2) + with pytest.raises(sa.exc.IntegrityError): + await db.commit() + + +async def test_foreign_key_constraint(db: Session): + """Test foreign key constraint violation.""" + hero = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team_id=uuid4(), + ) + db.add(hero) + with pytest.raises(sa.exc.IntegrityError): + await db.commit() + + +async def test_invalid_uuid(db: Session): + """Test invalid UUID.""" + with pytest.raises(ValueError): + Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team_id="invalid-uuid", + ) + + +async def test_implicit_io(db: Session, team: Team, item: Item): + """Test implicit I/O prevention.""" + hero = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team=team, + item=item, + ) + db.add(hero) + await db.commit() + await db.refresh(hero) + with pytest.raises(sa.exc.MissingGreenlet): + _ = hero.team.name # Accessing relationship without await + + +async def test_aware_datetime(db: Session, team: Team, item: Item): + """Test aware datetime fields.""" + hero = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team=team, + item=item, + ) + db.add(hero) + await db.commit() + await db.refresh(hero) + assert hero.created_at.tzinfo is not None + assert hero.created_at.tzinfo == timezone.utc + + +async def test_model_config(db: Session, team: Team, item: Item): + """Test model configuration.""" + hero = Hero( + name="Test Hero", + secret_identity="Test Identity", # nosec:B106 + age=30, + team=team, + item=item, + ) + db.add(hero) + await db.commit() + await db.refresh(hero) + assert hero.model_dump(by_alias=True)["secretIdentity"] == "Test Identity" diff --git a/uv.lock b/uv.lock index ef25005..37d6b7d 100644 --- a/uv.lock +++ b/uv.lock @@ -228,13 +228,14 @@ wheels = [ [[package]] name = "fastapi-async-sql" -version = "0.1.0a1" +version = "0.2.0a1" source = { editable = "." } dependencies = [ { name = "fastapi" }, { name = "fastapi-filter", extra = ["sqlalchemy"] }, { name = "fastapi-pagination" }, { name = "inflect" }, + { name = "sqlalchemy", extra = ["asyncio"] }, { name = "sqlmodel" }, ] @@ -281,6 +282,7 @@ requires-dist = [ { name = "pytest-mock", marker = "extra == 'test'", specifier = ">=3.14.0" }, { name = "ruff", marker = "extra == 'lint'", specifier = ">=0.6.2" }, { name = "ruff-lsp", marker = "extra == 'lint'", specifier = ">=0.0.55" }, + { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.32" }, { name = "sqlmodel", specifier = ">=0.0.21,<0.1.0" }, { name = "termynal", marker = "extra == 'docs'", specifier = ">=0.12.1" }, ] @@ -1088,6 +1090,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/1b/045185a9f6481d926a451aafaa0d07c98f19ac7abe730dff9630c9ead4fa/SQLAlchemy-2.0.32-py3-none-any.whl", hash = "sha256:e567a8793a692451f706b363ccf3c45e056b67d90ead58c3bc9471af5d212202", size = 1878765 }, ] +[package.optional-dependencies] +asyncio = [ + { name = "greenlet" }, +] + [[package]] name = "sqlmodel" version = "0.0.21" From d303346afa1157ab319a0d7a3b113fc25345c0aa Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Wed, 12 Feb 2025 23:12:59 -0700 Subject: [PATCH 5/8] fix: maintain db session active throughout request lifecycle - Add expire_on_commit=False to session config - Replace commit/refresh calls with flush in repositories - Wrap request handling in transaction context - Add test for multiple db operations in single request Bumps version to 0.1.0-alpha.2 --- Makefile | 6 ++-- pyproject.toml | 7 ++-- src/fastapi_async_sql/middlewares.py | 20 ++++++++--- src/fastapi_async_sql/repositories.py | 15 ++++---- tests/test_middlewares.py | 50 +++++++++++++++++++++++++-- uv.lock | 46 ++++++++++++++---------- 6 files changed, 106 insertions(+), 38 deletions(-) diff --git a/Makefile b/Makefile index 791ed26..3e596eb 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,12 @@ .PHONY: migrations lint: - uv run pre-commit install - uv run pre-commit run -a -v + uvx pre-commit install + uvx pre-commit run -a -v update: uv lock --upgrade - uv run pre-commit autoupdate -j 10 + uvx pre-commit autoupdate -j 10 sync: uv sync --all-extras diff --git a/pyproject.toml b/pyproject.toml index 6a0876d..f8179ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fastapi-async-sql" -version = "0.2.0-alpha.2" +version = "0.2.0-alpha.3" description = "Common utilities for Async SQL FastAPI applications" readme = "README.md" requires-python = ">=3.12" @@ -13,7 +13,7 @@ dependencies = [ "sqlalchemy[asyncio]>=2.0.32", ] -[project.optional-dependencies] +[dependency-groups] lint = [ "ruff>=0.6.2", "ruff-lsp>=0.0.55", @@ -36,6 +36,9 @@ docs = [ "mkdocstrings[python]>=0.25.2", ] +[tool.uv] +default-groups = ["lint", "test", "docs"] + [build-system] requires = ["hatchling"] build-backend = "hatchling.build" diff --git a/src/fastapi_async_sql/middlewares.py b/src/fastapi_async_sql/middlewares.py index ba66ee6..495db62 100644 --- a/src/fastapi_async_sql/middlewares.py +++ b/src/fastapi_async_sql/middlewares.py @@ -47,15 +47,27 @@ def __init__( else: self.engine = custom_engine + # Modify session defaults to keep session active + default_session_options = { + "expire_on_commit": False, # Prevent expiring objects after commit + "autoflush": True, + } + if session_options: + default_session_options.update(session_options) + self.async_session = async_sessionmaker( bind=self.engine, class_=AsyncSession, - **session_options or {}, + **default_session_options, ) async def dispatch(self, request: Request, call_next: RequestResponseEndpoint): """Method to dispatch the request.""" async with self.async_session() as session: - request.state.db = session - response = await call_next(request) - return response + # Begin transaction + async with session.begin(): + request.state.db = session + response = await call_next(request) + # Transaction will be committed automatically when exiting context + # Only if no exceptions occurred + return response diff --git a/src/fastapi_async_sql/repositories.py b/src/fastapi_async_sql/repositories.py index cc13351..24cadb1 100644 --- a/src/fastapi_async_sql/repositories.py +++ b/src/fastapi_async_sql/repositories.py @@ -122,13 +122,12 @@ async def create( try: session.add(db_obj) - await session.commit() + # Don't commit here - let the middleware handle the transaction + await session.flush() # Just flush to get generated values except exc.IntegrityError as err: - await session.rollback() raise CreateObjectError( obj=self.model.__name__, **db_obj.model_dump() ) from err - await session.refresh(db_obj) return db_obj async def update( @@ -146,14 +145,13 @@ async def update( else: update_data = obj_new.model_dump( exclude_unset=True, - exclude_defaults=True, - ) # This tells Pydantic to not include the values that were not sent + ) for field in update_data: setattr(obj_current, field, update_data[field]) session.add(obj_current) - await session.commit() - await session.refresh(obj_current) + # Don't commit here - let the middleware handle the transaction + await session.flush() return obj_current async def remove( @@ -165,7 +163,8 @@ async def remove( if obj is None: raise ObjectNotFoundError(obj=self.model.__name__, id=id) await session.delete(obj) - await session.commit() + # Don't commit here - let the middleware handle the transaction + await session.flush() return None # noinspection PyMethodMayBeStatic diff --git a/tests/test_middlewares.py b/tests/test_middlewares.py index 7a7754d..6a81a50 100644 --- a/tests/test_middlewares.py +++ b/tests/test_middlewares.py @@ -154,8 +154,7 @@ async def create_hero(request: Request): item_id=item.id, ) request.state.db.add(hero) - await request.state.db.commit() - await request.state.db.refresh(hero) + await request.state.db.flush() return hero response = await client.post("/heroes") @@ -166,3 +165,50 @@ async def create_hero(request: Request): assert assert_hero.name == "Batman" assert assert_hero.secret_identity == "Bruce Wayne" # nosec: B105 assert assert_hero.age == 40 + + +async def test_async_sqlalchemy_middleware_multiple_db_operations( + app_with_db_middleware: FastAPI, + client: httpx.AsyncClient, + db: AsyncSession, + team: Team, + item: Item, +): + """Test that the middleware correctly handles multiple database operations in a single request.""" + + @app_with_db_middleware.post( + "/heroes/create-and-update", + response_model=HeroReadSchema, + status_code=status.HTTP_201_CREATED, + ) + async def create_and_update_hero(request: Request): + # First operation: Create hero + hero = Hero( + name="Superman", + secret_identity="Clark Kent", # nosec: B106 + age=35, + team_id=team.id, + item_id=item.id, + ) + request.state.db.add(hero) + await request.state.db.flush() + + # Second operation: Update the hero's name + hero.name = "Superman Updated" + request.state.db.add(hero) + await request.state.db.flush() + + return hero + + response = await client.post("/heroes/create-and-update") + assert response.status_code == status.HTTP_201_CREATED + hero_data = response.json() + + # Verify the hero was created and updated correctly + assert_hero = await db.get(Hero, UUID(hero_data["id"])) + assert assert_hero is not None + assert assert_hero.name == "Superman Updated" # Verify the update worked + assert assert_hero.secret_identity == "Clark Kent" # nosec: B105 + assert assert_hero.age == 35 + assert assert_hero.team_id == team.id + assert assert_hero.item_id == item.id diff --git a/uv.lock b/uv.lock index 37d6b7d..357a3e6 100644 --- a/uv.lock +++ b/uv.lock @@ -116,7 +116,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -228,7 +228,7 @@ wheels = [ [[package]] name = "fastapi-async-sql" -version = "0.2.0a1" +version = "0.2.0a3" source = { editable = "." } dependencies = [ { name = "fastapi" }, @@ -239,7 +239,7 @@ dependencies = [ { name = "sqlmodel" }, ] -[package.optional-dependencies] +[package.dev-dependencies] docs = [ { name = "mdx-include" }, { name = "mkdocs-git-revision-date-localized-plugin" }, @@ -264,27 +264,35 @@ test = [ [package.metadata] requires-dist = [ - { name = "aiosqlite", marker = "extra == 'test'", specifier = ">=0.20.0" }, - { name = "factory-boy", marker = "extra == 'test'", specifier = ">=3.3.1" }, { name = "fastapi", specifier = ">=0.100.0,<0.113.0" }, { name = "fastapi-filter", extras = ["sqlalchemy"], specifier = ">=2.0.0,<3.0.0" }, { name = "fastapi-pagination", specifier = ">=0.12.26,<0.13.0" }, - { name = "httpx", marker = "extra == 'test'", specifier = ">=0.27.2" }, { name = "inflect", specifier = ">=7.3.1,<8.0.0" }, - { name = "mdx-include", marker = "extra == 'docs'", specifier = ">=1.4.2" }, - { name = "mkdocs-git-revision-date-localized-plugin", marker = "extra == 'docs'", specifier = ">=1.2.7" }, - { name = "mkdocs-material", marker = "extra == 'docs'", specifier = ">=9.5.33" }, - { name = "mkdocstrings", extras = ["python"], marker = "extra == 'docs'", specifier = ">=0.25.2" }, - { name = "pre-commit", marker = "extra == 'lint'", specifier = ">=3.8.0" }, - { name = "pytest", marker = "extra == 'test'", specifier = ">=8.3.2" }, - { name = "pytest-asyncio", marker = "extra == 'test'", specifier = ">=0.24.0" }, - { name = "pytest-cov", marker = "extra == 'test'", specifier = ">=5.0.0" }, - { name = "pytest-mock", marker = "extra == 'test'", specifier = ">=3.14.0" }, - { name = "ruff", marker = "extra == 'lint'", specifier = ">=0.6.2" }, - { name = "ruff-lsp", marker = "extra == 'lint'", specifier = ">=0.0.55" }, { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0.32" }, { name = "sqlmodel", specifier = ">=0.0.21,<0.1.0" }, - { name = "termynal", marker = "extra == 'docs'", specifier = ">=0.12.1" }, +] + +[package.metadata.requires-dev] +docs = [ + { name = "mdx-include", specifier = ">=1.4.2" }, + { name = "mkdocs-git-revision-date-localized-plugin", specifier = ">=1.2.7" }, + { name = "mkdocs-material", specifier = ">=9.5.33" }, + { name = "mkdocstrings", extras = ["python"], specifier = ">=0.25.2" }, + { name = "termynal", specifier = ">=0.12.1" }, +] +lint = [ + { name = "pre-commit", specifier = ">=3.8.0" }, + { name = "ruff", specifier = ">=0.6.2" }, + { name = "ruff-lsp", specifier = ">=0.0.55" }, +] +test = [ + { name = "aiosqlite", specifier = ">=0.20.0" }, + { name = "factory-boy", specifier = ">=3.3.1" }, + { name = "httpx", specifier = ">=0.27.2" }, + { name = "pytest", specifier = ">=8.3.2" }, + { name = "pytest-asyncio", specifier = ">=0.24.0" }, + { name = "pytest-cov", specifier = ">=5.0.0" }, + { name = "pytest-mock", specifier = ">=3.14.0" }, ] [[package]] @@ -551,7 +559,7 @@ version = "1.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, From 563a944cfcabc0093321809b6f652fd52eafdf2e Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Thu, 20 Feb 2025 17:18:57 -0700 Subject: [PATCH 6/8] feat: Add support for intercepted filtering in repositories - Enhance BaseRepository filter method to handle None filter_by more robustly - Add test case demonstrating custom filter with additional constraints - Bump version to 0.2.0-alpha.4 --- pyproject.toml | 2 +- src/fastapi_async_sql/repositories.py | 2 +- tests/test_api.py | 61 +++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f8179ed..3f112e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fastapi-async-sql" -version = "0.2.0-alpha.3" +version = "0.2.0-alpha.4" description = "Common utilities for Async SQL FastAPI applications" readme = "README.md" requires-python = ">=3.12" diff --git a/src/fastapi_async_sql/repositories.py b/src/fastapi_async_sql/repositories.py index 24cadb1..1b74ff1 100644 --- a/src/fastapi_async_sql/repositories.py +++ b/src/fastapi_async_sql/repositories.py @@ -179,7 +179,7 @@ def _apply_query_filter( self, query: T | Select[T], filter_by: Filter | None ) -> T | Select[T]: """Get the query with the filter applied.""" - if filter_by: + if filter_by is not None: query = filter_by.filter(query) if getattr(filter_by, filter_by.Constants.ordering_field_name, None): query = filter_by.sort(query) diff --git a/tests/test_api.py b/tests/test_api.py index 14da301..b10cacd 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -404,3 +404,64 @@ async def get_heroes( "createdById": str(hero.item.created_by_id), }, } + + +async def test_api_filtering_with_intercepted_filter( + app: FastAPI, + client: httpx.AsyncClient, + marvel_heroes, +): + """Test API filtering with intercepted filter to add additional constraints.""" + + class HeroFilter(Filter): + name: str | None = None + age__gt: int | None = None + order_by: list[str] | None = Field(default_factory=list) + + class Constants(Filter.Constants): + model = Hero + + def filter(self, query): + # First apply the original filters + query = super().filter(query) + # Then add our additional filter + return query.where(Hero.name != "Vision") + + @app.get("/heroes/filtered") + async def get_filtered_heroes( + repository: AnnotatedRepositoryHero, + filter_by: HeroFilter = FilterDepends(HeroFilter), + ) -> list[HeroReadSchema]: + return await repository.get_multi(filter_by=filter_by) + + # Test with no filters - should return all heroes except Vision + response = await client.get("/heroes/filtered") + assert response.status_code == status.HTTP_200_OK + response_data = response.json() + response_names = [hero["name"] for hero in response_data] + assert "Vision" not in response_names + assert len(response_names) == len(marvel_heroes) - 1 + + # Test with age filter - should still exclude Vision + response = await client.get("/heroes/filtered?age__gt=40") + assert response.status_code == status.HTTP_200_OK + response_data = response.json() + response_names = [hero["name"] for hero in response_data] + assert "Vision" not in response_names + assert all(hero["age"] > 40 for hero in response_data) + + # Test with name filter - should still exclude Vision + response = await client.get("/heroes/filtered?name=Thor") + assert response.status_code == status.HTTP_200_OK + response_data = response.json() + response_names = [hero["name"] for hero in response_data] + assert "Vision" not in response_names + assert all(hero["name"] == "Thor" for hero in response_data) + + # Test with ordering - should return ordered list without Vision + response = await client.get("/heroes/filtered?order_by=name") + assert response.status_code == status.HTTP_200_OK + response_data = response.json() + response_names = [hero["name"] for hero in response_data] + assert "Vision" not in response_names + assert response_names == sorted(response_names) From 0037b656c00b436feca860a4ac949785cb2b6d22 Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Thu, 10 Apr 2025 22:49:54 -0600 Subject: [PATCH 7/8] feat: add PydanticJSONB typing for postgres JSONB Bumps version to 0.1.0-alpha.5 --- pyproject.toml | 2 +- src/fastapi_async_sql/typing.py | 87 +++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 3f112e0..6d2f6e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fastapi-async-sql" -version = "0.2.0-alpha.4" +version = "0.2.0-alpha.5" description = "Common utilities for Async SQL FastAPI applications" readme = "README.md" requires-python = ">=3.12" diff --git a/src/fastapi_async_sql/typing.py b/src/fastapi_async_sql/typing.py index 0fa66db..b60a164 100644 --- a/src/fastapi_async_sql/typing.py +++ b/src/fastapi_async_sql/typing.py @@ -1,8 +1,13 @@ """This module defines custom types for the FastAPI Async SQL package.""" from datetime import datetime, timezone +from typing import Any, TypeVar, Union, get_args, get_origin import sqlalchemy as sa +from pydantic import BaseModel +from pydantic_core import to_jsonable_python +from sqlalchemy import types +from sqlalchemy.dialects.postgresql import JSONB class TimeStamp(sa.types.TypeDecorator): @@ -26,3 +31,85 @@ def process_result_value(self, value, dialect): return value.replace(tzinfo=timezone.utc) return value.astimezone(timezone.utc) + + +BaseModelType = TypeVar("BaseModelType", bound=BaseModel) + +# Define a type alias for JSON-serializable values +JSONValue = Union[dict[str, Any], list[Any], str, int, float, bool, None] + + +class PydanticJSONB(types.TypeDecorator): # type: ignore + """Custom type to automatically handle Pydantic model serialization. + + Until this PR is merged: https://github.com/fastapi/sqlmodel/pull/1324 + """ + + impl = JSONB # use JSONB type in Postgres (fallback to JSON for others) + cache_ok = True # allow SQLAlchemy to cache results + + def __init__( + self, + model_class: ( + type[BaseModelType] + | type[list[BaseModelType]] + | type[dict[str, BaseModelType]] + ), + *args: Any, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + self.model_class = model_class # Pydantic model class to use + + def process_bind_param(self, value: Any, dialect: Any) -> JSONValue: # noqa: ANN401, ARG002, ANN001 + """Convert Pydantic model to JSON-serializable value.""" + if value is None: + return None + if isinstance(value, BaseModel): + return value.model_dump(mode="json") + if isinstance(value, list): + return [ + m.model_dump(mode="json") + if isinstance(m, BaseModel) + else to_jsonable_python(m) + for m in value + ] + if isinstance(value, dict): + return { + k: v.model_dump(mode="json") + if isinstance(v, BaseModel) + else to_jsonable_python(v) + for k, v in value.items() + } + + # We know to_jsonable_python returns a JSON-serializable value, but mypy sees it as Any + return to_jsonable_python(value) # type: ignore[no-any-return] + + def process_result_value( + self, value: Any, dialect: Any + ) -> BaseModelType | list[BaseModelType] | dict[str, BaseModelType] | None: # noqa: ANN401, ARG002, ANN001 + """Convert JSONB value from database to Pydantic model.""" + if value is None: + return None + if isinstance(value, dict): + # If model_class is a Dict type hint, handle key-value pairs + origin = get_origin(self.model_class) + if origin is dict: + model_class = get_args(self.model_class)[ + 1 + ] # Get the value type (the model) + return {k: model_class.model_validate(v) for k, v in value.items()} + # Regular case: the whole dict represents a single model + return self.model_class.model_validate(value) # type: ignore + if isinstance(value, list): + # If model_class is a List type hint + origin = get_origin(self.model_class) + if origin is list: + model_class = get_args(self.model_class)[0] + return [model_class.model_validate(v) for v in value] + # Fallback case (though this shouldn't happen given our __init__ types) + return [self.model_class.model_validate(v) for v in value] # type: ignore + + raise TypeError( + f"Unsupported type for PydanticJSONB from database: {type(value)}. Expected a dictionary or list." + ) From 54ca8793f7faf52ec7f36b106364381669875700 Mon Sep 17 00:00:00 2001 From: Patrick Rodrigues <23041890+pythrick@users.noreply.github.com> Date: Thu, 10 Apr 2025 22:50:51 -0600 Subject: [PATCH 8/8] fix: fix BaseRepository.create method to handle model instances as obj_in parameter Bumps version to 0.1.0-alpha.5 --- pyproject.toml | 2 +- src/fastapi_async_sql/repositories.py | 5 ++++- tests/test_repositories.py | 26 ++++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6d2f6e1..0280836 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "fastapi-async-sql" -version = "0.2.0-alpha.5" +version = "0.2.0-alpha.6" description = "Common utilities for Async SQL FastAPI applications" readme = "README.md" requires-python = ">=3.12" diff --git a/src/fastapi_async_sql/repositories.py b/src/fastapi_async_sql/repositories.py index 1b74ff1..aef429f 100644 --- a/src/fastapi_async_sql/repositories.py +++ b/src/fastapi_async_sql/repositories.py @@ -118,7 +118,10 @@ async def create( ) -> ModelType: """Create a new object.""" session = self._get_db_session(db_session) - db_obj = self.model.model_validate(obj_in, update=extra_data) + if isinstance(obj_in, self.model): + db_obj = obj_in.model_copy(update=extra_data) + else: + db_obj = self.model.model_validate(obj_in, update=extra_data) try: session.add(db_obj) diff --git a/tests/test_repositories.py b/tests/test_repositories.py index b1c9ee3..5028998 100644 --- a/tests/test_repositories.py +++ b/tests/test_repositories.py @@ -72,6 +72,32 @@ async def test_create_hero_duplicate( await hero_repository.create(obj_in=hero_data) +async def test_create_hero_from_model_instance( + db: AsyncSession, + team: Team, + item: Item, + hero_repository: HeroRepository, +): + """Test create hero from a model instance (not schema).""" + hero_instance = Hero( + name="Instance Hero", + age=25, + secret_identity="Instance Identity", # nosec: B106 + team_id=team.id, + item_id=item.id, + ) + + created_hero = await hero_repository.create(obj_in=hero_instance) + + assert created_hero.id is not None + assert created_hero.name == "Instance Hero" + assert created_hero.age == 25 + assert created_hero.secret_identity == "Instance Identity" # nosec: B105 + assert created_hero.team_id == team.id + assert created_hero.item_id == item.id + assert isinstance(created_hero.created_at, datetime) + + async def test_get_hero( db: AsyncSession, heroes: list[Hero],