diff --git a/src/starlite_saqlalchemy/worker.py b/src/starlite_saqlalchemy/worker.py
index a0d37997..4ca68089 100644
--- a/src/starlite_saqlalchemy/worker.py
+++ b/src/starlite_saqlalchemy/worker.py
@@ -2,12 +2,14 @@
 from __future__ import annotations
 
 import asyncio
+from functools import partial
 from typing import TYPE_CHECKING, Any
 
 import msgspec
 import saq
+from starlite.utils.serialization import default_serializer
 
-from starlite_saqlalchemy import redis, settings
+from starlite_saqlalchemy import redis, settings, type_encoders
 
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Collection
@@ -20,6 +22,10 @@
     "queue",
 ]
 
+encoder = msgspec.json.Encoder(
+    enc_hook=partial(default_serializer, type_encoders=type_encoders.type_encoders_map)
+)
+
 
 class Queue(saq.Queue):
     """Async task queue."""
@@ -36,7 +42,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
             **kwargs: Passed through to `saq.Queue.__init__()`
         """
         kwargs.setdefault("name", settings.app.slug)
-        kwargs.setdefault("dump", msgspec.json.encode)
+        kwargs.setdefault("dump", encoder.encode)
         kwargs.setdefault("load", msgspec.json.decode)
         super().__init__(*args, **kwargs)
diff --git a/tests/conftest.py b/tests/conftest.py
index 67398654..fa2db5ed 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -7,6 +7,7 @@
 from uuid import uuid4
 
 import pytest
+from asyncpg.pgproto import pgproto
 from starlite import Starlite
 from structlog.contextvars import clear_contextvars
 from structlog.testing import CapturingLogger
@@ -82,7 +83,11 @@ def fx_raw_authors() -> list[dict[str, Any]]:
 @pytest.fixture(name="authors")
 def fx_authors(raw_authors: list[dict[str, Any]]) -> list[authors.Author]:
     """Collection of parsed Author models."""
-    return [authors.ReadDTO(**raw).to_mapped() for raw in raw_authors]
+    mapped_authors = [authors.ReadDTO(**raw).to_mapped() for raw in raw_authors]
+    # convert these to pgproto UUIDs as that is what we get back from sqlalchemy
+    for author in mapped_authors:
+        author.id = pgproto.UUID(str(author.id))
+    return mapped_authors
 
 
 @pytest.fixture(name="raw_books")
@@ -103,7 +108,11 @@ def fx_raw_books(raw_authors: list[dict[str, Any]]) -> list[dict[str, Any]]:
 @pytest.fixture(name="books")
 def fx_books(raw_books: list[dict[str, Any]]) -> list[books.Book]:
     """Collection of parsed Book models."""
-    return [books.ReadDTO(**raw).to_mapped() for raw in raw_books]
+    mapped_books = [books.ReadDTO(**raw).to_mapped() for raw in raw_books]
+    # convert these to pgproto UUIDs as that is what we get back from sqlalchemy
+    for book in mapped_books:
+        book.id = pgproto.UUID(str(book.id))
+    return mapped_books
 
 
 @pytest.fixture(name="create_module")
diff --git a/tests/unit/test_worker.py b/tests/unit/test_worker.py
new file mode 100644
index 00000000..304e9251
--- /dev/null
+++ b/tests/unit/test_worker.py
@@ -0,0 +1,24 @@
+"""Tests for the SAQ async worker functionality."""
+from __future__ import annotations
+
+from asyncpg.pgproto import pgproto
+
+from starlite_saqlalchemy import worker
+from tests.utils.domain.authors import Author, ReadDTO
+
+
+def test_worker_decoder_handles_pgproto_uuid() -> None:
+    """Test that the decoder can handle pgproto.UUID instances."""
+    pg_uuid = pgproto.UUID("0448bde2-7c69-4e6b-9c03-7b217e3b563d")
+    encoded = worker.encoder.encode(pg_uuid)
+    assert encoded == b'"0448bde2-7c69-4e6b-9c03-7b217e3b563d"'
+
+
+def test_worker_decoder_handles_pydantic_models(authors: list[Author]) -> None:
+    """Test that the decoder we use for SAQ will encode a pydantic model."""
+    pydantic_model = ReadDTO.from_orm(authors[0])
+    encoded = worker.encoder.encode(pydantic_model)
+    assert (
+        encoded
+        == b'{"id":"97108ac1-ffcb-411d-8b1e-d9183399f63b","created":"0001-01-01T00:00:00","updated":"0001-01-01T00:00:00","name":"Agatha Christie","dob":"1890-09-15"}'
+    )