diff --git a/README.md b/README.md index 3e09f74..4d03b4c 100644 --- a/README.md +++ b/README.md @@ -17,3 +17,4 @@ Reference https://github.com/cosmicpython/code - [✅ CHAPTER_01 domain model](https://github.com/sawaca96/architecture-patterns-with-python/commit/fb604a0bc25b70a98e16dc4185eb8c9eb96ceb3d) - [✅ CHAPTER_02 repository](https://github.com/sawaca96/architecture-patterns-with-python/commit/fb604a0bc25b70a98e16dc4185eb8c9eb96ceb3d) - [4️⃣ CHAPTER_04 usecase](https://github.com/sawaca96/architecture-patterns-with-python/pull/1) +- [6️⃣ CHAPTER_06 unit of work](https://github.com/sawaca96/architecture-patterns-with-python/pull/2#pullrequestreview-1265028411) diff --git a/app/allocation/__init__.py b/app/allocation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/allocation/adapters/__init__.py b/app/allocation/adapters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/db.py b/app/allocation/adapters/db.py similarity index 96% rename from app/db.py rename to app/allocation/adapters/db.py index 36ae1f6..7cb2c61 100644 --- a/app/db.py +++ b/app/allocation/adapters/db.py @@ -12,7 +12,6 @@ def __init__(self, url: str) -> None: self._session_factory = async_scoped_session( sessionmaker( autocommit=False, - autoflush=False, class_=AsyncSession, bind=self._engine, ), diff --git a/app/orm.py b/app/allocation/adapters/orm.py similarity index 94% rename from app/orm.py rename to app/allocation/adapters/orm.py index e54cd86..ceeddd7 100644 --- a/app/orm.py +++ b/app/allocation/adapters/orm.py @@ -2,7 +2,7 @@ from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import registry, relationship -from app import models +from app.allocation.domain import models mapper_registry = registry() metadata = mapper_registry.metadata @@ -20,7 +20,7 @@ metadata, sa.Column("id", UUID(as_uuid=True), primary_key=True), sa.Column("sku", sa.String), - sa.Column("purchased_quantity", sa.Integer), + sa.Column("qty", sa.Integer), sa.Column("eta", sa.Date, nullable=True), ) diff --git a/app/repository.py b/app/allocation/adapters/repository.py similarity index 60% rename from app/repository.py rename to app/allocation/adapters/repository.py index ad99894..47cb49a 100644 --- a/app/repository.py +++ b/app/allocation/adapters/repository.py @@ -3,12 +3,12 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import subqueryload +from sqlalchemy.orm import selectinload, subqueryload -from app import models +from app.allocation.domain import models -class BatchAbstractRepository(abc.ABC): +class AbstractBatchRepository(abc.ABC): @abc.abstractmethod async def add(self, batch: models.Batch) -> None: raise NotImplementedError @@ -22,7 +22,7 @@ async def list(self) -> list[models.Batch]: raise NotImplementedError -class PGBatchRepository(BatchAbstractRepository): +class PGBatchRepository(AbstractBatchRepository): def __init__(self, session: AsyncSession) -> None: self._session = session @@ -31,7 +31,14 @@ async def add(self, batch: models.Batch) -> None: await self._session.flush() async def get(self, id: UUID) -> models.Batch: - return await self._session.get(models.Batch, id) + # async sqlalchemy doesn't support relationship + # It raise 'greenlet_spawn has not been called; can't call await_() here. Was IO attempted in an unexpected place?' # noqa E501 + result = await self._session.execute( + sa.select(models.Batch) + .where(models.Batch.id == id) + .options(selectinload(models.Batch.allocations)) + ) + return result.scalar_one() async def list(self) -> list[models.Batch]: result = await self._session.execute( diff --git a/app/allocation/domain/__init__.py b/app/allocation/domain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/models.py b/app/allocation/domain/models.py similarity index 75% rename from app/models.py rename to app/allocation/domain/models.py index 2d70359..2759634 100644 --- a/app/models.py +++ b/app/allocation/domain/models.py @@ -1,8 +1,8 @@ from __future__ import annotations -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import date -from uuid import UUID +from uuid import UUID, uuid4 class OutOfStock(Exception): @@ -18,21 +18,20 @@ def allocate(line: OrderLine, batches: list[Batch]) -> UUID: raise OutOfStock(f"Out of stock for sku {line.sku}") -@dataclass(unsafe_hash=True, kw_only=True) # TODO: kw_only를 언제 써야 할까? +@dataclass(unsafe_hash=True, kw_only=True) class OrderLine: - id: UUID + id: UUID = field(default_factory=uuid4) sku: str qty: int +@dataclass(kw_only=True) class Batch: - def __init__(self, id: UUID, sku: str, qty: int, eta: date | None) -> None: - # TODO: id 값 업으면 기본 값 채우기 - self.id = id - self.sku = sku - self.eta = eta - self.purchased_quantity = qty - self.allocations: set[OrderLine] = set() + id: UUID = field(default_factory=uuid4) + sku: str + eta: date = None + qty: int + allocations: set[OrderLine] = field(default_factory=lambda: set()) def __repr__(self) -> str: return f"" @@ -66,7 +65,7 @@ def allocated_quantity(self) -> int: @property def available_quantity(self) -> int: - return self.purchased_quantity - self.allocated_quantity + return self.qty - self.allocated_quantity def can_allocate(self, line: OrderLine) -> bool: return ( diff --git a/app/allocation/routers/__init__.py b/app/allocation/routers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/dependencies.py b/app/allocation/routers/dependencies.py similarity index 59% rename from app/dependencies.py rename to app/allocation/routers/dependencies.py index 41b4b58..553df83 100644 --- a/app/dependencies.py +++ b/app/allocation/routers/dependencies.py @@ -4,9 +4,10 @@ from fastapi import Depends from sqlalchemy.ext.asyncio import AsyncSession +from app.allocation.adapters.db import DB +from app.allocation.adapters.repository import AbstractBatchRepository, PGBatchRepository +from app.allocation.service_layer.unit_of_work import AbstractUnitOfWork, BatchUnitOfWork from app.config import get_config -from app.db import DB -from app.repository import BatchAbstractRepository, PGBatchRepository config = get_config() @@ -23,5 +24,9 @@ async def session(db: DB = Depends(db)) -> AsyncGenerator[AsyncSession, None]: def repository( session: AsyncSession = Depends(session), -) -> BatchAbstractRepository: +) -> AbstractBatchRepository: return PGBatchRepository(session) + + +def batch_uow() -> AbstractUnitOfWork[AbstractBatchRepository]: + return BatchUnitOfWork() diff --git a/app/allocation/routers/main.py b/app/allocation/routers/main.py new file mode 100644 index 0000000..132be4e --- /dev/null +++ b/app/allocation/routers/main.py @@ -0,0 +1,44 @@ +from datetime import date +from uuid import UUID + +from fastapi import Body, Depends, FastAPI, HTTPException + +from app.allocation.adapters.repository import AbstractBatchRepository +from app.allocation.domain import models +from app.allocation.routers.dependencies import batch_uow +from app.allocation.service_layer import services +from app.allocation.service_layer.unit_of_work import AbstractUnitOfWork + +app = FastAPI() +# start_mappers() # TODO: 운영환경에서는 실행되어야 함 + + +@app.get("/") +async def root() -> dict[str, str]: + return {"message": "Hello World"} + + +@app.post("/batches", status_code=201) +async def add_batch( + batch_id: UUID = Body(), + sku: str = Body(), + quantity: int = Body(), + eta: date = Body(default=None), + uow: AbstractUnitOfWork[AbstractBatchRepository] = Depends(batch_uow), +) -> dict[str, str]: + await services.add_batch(batch_id, sku, quantity, eta, uow) + return {"message": "success"} + + +@app.post("/allocate", response_model=dict[str, str], status_code=201) +async def allocate( + line_id: UUID = Body(), + sku: str = Body(), + quantity: int = Body(), + uow: AbstractUnitOfWork[AbstractBatchRepository] = Depends(batch_uow), +) -> dict[str, str]: + try: + batch_id = await services.allocate(line_id, sku, quantity, uow) + except (models.OutOfStock, services.InvalidSku) as e: + raise HTTPException(status_code=400, detail=str(e)) + return {"batch_id": str(batch_id)} diff --git a/app/allocation/service_layer/__init__.py b/app/allocation/service_layer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/allocation/service_layer/services.py b/app/allocation/service_layer/services.py new file mode 100644 index 0000000..da19193 --- /dev/null +++ b/app/allocation/service_layer/services.py @@ -0,0 +1,39 @@ +from datetime import date +from uuid import UUID + +from app.allocation.adapters.repository import AbstractBatchRepository +from app.allocation.domain import models +from app.allocation.service_layer import unit_of_work + + +class InvalidSku(Exception): + pass + + +async def add_batch( + batch_id: UUID, + sku: str, + qty: int, + eta: date | None, + uow: unit_of_work.AbstractUnitOfWork[AbstractBatchRepository], +) -> None: + async with uow: + await uow.repo.add(models.Batch(id=batch_id, sku=sku, qty=qty, eta=eta)) + await uow.commit() + + +async def allocate( + line_id: UUID, sku: str, qty: int, uow: unit_of_work.AbstractUnitOfWork[AbstractBatchRepository] +) -> UUID: + line = models.OrderLine(id=line_id, sku=sku, qty=qty) + async with uow: + batches = await uow.repo.list() + if not _is_valid_sku(line.sku, batches): + raise InvalidSku(f"Invalid sku {line.sku}") + batch_id = models.allocate(line, batches) + await uow.commit() + return batch_id + + +def _is_valid_sku(sku: str, batches: list[models.Batch]) -> bool: + return sku in {b.sku for b in batches} diff --git a/app/allocation/service_layer/unit_of_work.py b/app/allocation/service_layer/unit_of_work.py new file mode 100644 index 0000000..68d7efb --- /dev/null +++ b/app/allocation/service_layer/unit_of_work.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +import abc +from asyncio import current_task +from typing import Any, Generic, TypeVar + +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session, create_async_engine +from sqlalchemy.orm import sessionmaker + +from app.allocation.adapters.repository import AbstractBatchRepository, PGBatchRepository +from app.config import get_config + +config = get_config() + +Repo = TypeVar("Repo") + + +class AbstractUnitOfWork(abc.ABC, Generic[Repo]): + async def __aenter__(self) -> AbstractUnitOfWork[Repo]: + return self + + async def __aexit__(self, *args: Any) -> None: + await self.rollback() + + @abc.abstractproperty + def repo(self) -> Repo: + raise NotImplementedError + + @abc.abstractmethod + async def commit(self) -> None: + raise NotImplementedError + + @abc.abstractmethod + async def rollback(self) -> None: + raise NotImplementedError + + +class BatchUnitOfWork(AbstractUnitOfWork[AbstractBatchRepository]): + def __init__(self) -> None: + self._engine = create_async_engine(config.PG_DSN, echo=False) + self._session_factory = async_scoped_session( + sessionmaker( + autocommit=False, + autoflush=False, + class_=AsyncSession, + bind=self._engine, + ), + scopefunc=current_task, + ) + + @property + def repo(self) -> AbstractBatchRepository: + return self._batches + + async def __aenter__(self) -> AbstractUnitOfWork[AbstractBatchRepository]: + self._session: AsyncSession = self._session_factory() + self._batches = PGBatchRepository(self._session) + return await super().__aenter__() + + async def __aexit__(self, *args: Any) -> None: + await super().__aexit__(*args) + await self._session.close() + + async def commit(self) -> None: + await self._session.commit() + + async def rollback(self) -> None: + await self._session.rollback() diff --git a/app/main.py b/app/main.py deleted file mode 100644 index 9f629d7..0000000 --- a/app/main.py +++ /dev/null @@ -1,29 +0,0 @@ -from uuid import UUID - -from fastapi import Body, Depends, FastAPI, HTTPException - -from app import models, services -from app.dependencies import repository -from app.repository import BatchAbstractRepository - -app = FastAPI() - - -@app.get("/") -async def root() -> dict[str, str]: - return {"message": "Hello World"} - - -@app.post("/allocate", response_model=dict[str, str], status_code=201) -async def allocate( - order_id: UUID = Body(), - sku: str = Body(), - quantity: int = Body(), - repo: BatchAbstractRepository = Depends(repository), -) -> dict[str, str]: - line = models.OrderLine(id=order_id, sku=sku, qty=quantity) - try: - batch_id = await services.allocate(line, repo) - except (models.OutOfStock, services.InvalidSku) as e: - raise HTTPException(status_code=400, detail=str(e)) - return {"batch_id": str(batch_id)} diff --git a/app/services.py b/app/services.py deleted file mode 100644 index 16ac0d2..0000000 --- a/app/services.py +++ /dev/null @@ -1,21 +0,0 @@ -from uuid import UUID - -from app import models -from app.models import OrderLine -from app.repository import BatchAbstractRepository - - -class InvalidSku(Exception): - pass - - -async def allocate(line: OrderLine, repo: BatchAbstractRepository) -> UUID: - batches = await repo.list() - if not _is_valid_sku(line.sku, batches): - raise InvalidSku(f"Invalid sku {line.sku}") - batch_id = models.allocate(line, batches) - return batch_id - - -def _is_valid_sku(sku: str, batches: list[models.Batch]) -> bool: - return sku in {b.sku for b in batches} diff --git a/tests/conftest.py b/tests/conftest.py index 94794a5..4707a9d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,9 +6,9 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from sqlalchemy.orm import clear_mappers +from app.allocation.adapters.db import DB +from app.allocation.adapters.orm import metadata, start_mappers from app.config import get_config -from app.db import DB -from app.orm import metadata, start_mappers config = get_config() db = DB(config.PG_DSN) diff --git a/tests/test_api.py b/tests/e2e/test_api.py similarity index 77% rename from tests/test_api.py rename to tests/e2e/test_api.py index fa868e7..e7aeb1a 100644 --- a/tests/test_api.py +++ b/tests/e2e/test_api.py @@ -1,6 +1,6 @@ +from collections.abc import AsyncGenerator, Generator from datetime import date from typing import Any -from collections.abc import AsyncGenerator, Generator from uuid import UUID, uuid4 import pytest @@ -9,8 +9,8 @@ from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession from sqlalchemy.orm import sessionmaker -from app.main import app -from app.orm import metadata +from app.allocation.adapters.orm import metadata +from app.allocation.routers.main import app @pytest.fixture(scope="session") @@ -33,6 +33,15 @@ async def clear_db(session: AsyncSession) -> AsyncGenerator[Any, Any]: await session.execute(table.delete()) +async def test_add_batch_returns_201(client: TestClient) -> None: + # Given + res = client.post("/batches", json={"batch_id": str(uuid4()), "sku": "SKU", "quantity": 3}) + + # Then + assert res.status_code == 201 + assert res.json() == {"message": "success"} + + async def test_allocate_api_returns_201_and_allocated_batch( session: AsyncSession, client: TestClient ) -> None: @@ -44,10 +53,7 @@ async def test_allocate_api_returns_201_and_allocated_batch( ] for id, sku, qty, eta in batches: await session.execute( - sa.text( - "INSERT INTO batch (id, sku, purchased_quantity, eta) " - "VALUES (:id, :sku, :qty, :eta)" - ), + sa.text("INSERT INTO batch (id, sku, qty, eta) " "VALUES (:id, :sku, :qty, :eta)"), dict(id=id, sku=sku, qty=qty, eta=eta), ) await session.execute( @@ -57,7 +63,7 @@ async def test_allocate_api_returns_201_and_allocated_batch( await session.commit() # When - res = client.post("/allocate", json={"order_id": str(uuid4()), "sku": "SKU", "quantity": 3}) + res = client.post("/allocate", json={"line_id": str(uuid4()), "sku": "SKU", "quantity": 3}) # Then: order line is allocated to the batch with earliest eta, and status code 201 assert res.status_code == 201 @@ -69,7 +75,7 @@ async def test_allocate_api_returns_400_and_error_message_if_invalid_sku( ) -> None: # When: request with invalid sku res = client.post( - "/allocate", json={"order_id": str(uuid4()), "sku": "NOT-EXIST-SKU", "quantity": 3} + "/allocate", json={"line_id": str(uuid4()), "sku": "NOT-EXIST-SKU", "quantity": 3} ) # Then: status code 400 and error message diff --git a/tests/test_orm.py b/tests/integration/test_orm.py similarity index 83% rename from tests/test_orm.py rename to tests/integration/test_orm.py index e55b3ff..504a94e 100644 --- a/tests/test_orm.py +++ b/tests/integration/test_orm.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from app import models +from app.allocation.domain import models async def test_order_line_mapper_can_load_lines(session: AsyncSession) -> None: @@ -54,17 +54,17 @@ async def test_retrieving_batches(session: AsyncSession) -> None: # Given await session.execute( sa.text( - "INSERT INTO batch (id, sku, purchased_quantity, eta) VALUES " + "INSERT INTO batch (id, sku, qty, eta) VALUES " "('fa3310b7-2a44-4f0b-be0e-cf6ba8e201fb', 'RETRO-CLOCK', 100, null)," "('db14a922-e95b-481c-ac67-476ca819e96d', 'MINIMALIST-SPOON', 100, '2011-01-02')" ) ) expected = [ - models.Batch(UUID("fa3310b7-2a44-4f0b-be0e-cf6ba8e201fb"), "RETRO-CLOCK", 100, eta=None), + models.Batch(id=UUID("fa3310b7-2a44-4f0b-be0e-cf6ba8e201fb"), sku="RETRO-CLOCK", qty=100), models.Batch( - UUID("db14a922-e95b-481c-ac67-476ca819e96d"), - "MINIMALIST-SPOON", - 100, + id=UUID("db14a922-e95b-481c-ac67-476ca819e96d"), + sku="MINIMALIST-SPOON", + qty=100, eta=date(2011, 1, 2), ), ] @@ -78,14 +78,16 @@ async def test_retrieving_batches(session: AsyncSession) -> None: async def test_saving_batches(session: AsyncSession) -> None: # Given - batch = models.Batch(UUID("c7b6f091-bc25-458b-9027-1aa52fc3d9e1"), "RETRO-CLOCK", 100, eta=None) + batch = models.Batch( + id=UUID("c7b6f091-bc25-458b-9027-1aa52fc3d9e1"), sku="RETRO-CLOCK", qty=100 + ) # When session.add(batch) await session.flush() # Then - rows = await session.execute(sa.text("SELECT id, sku, purchased_quantity, eta FROM batch")) + rows = await session.execute(sa.text("SELECT id, sku, qty, eta FROM batch")) assert rows.fetchall() == [ (UUID("c7b6f091-bc25-458b-9027-1aa52fc3d9e1"), "RETRO-CLOCK", 100, None) ] @@ -93,14 +95,17 @@ async def test_saving_batches(session: AsyncSession) -> None: async def test_saving_allocations(session: AsyncSession) -> None: # Given - batch = models.Batch(UUID("9ef1794c-617b-4634-82e6-dda1f466ea72"), "RETRO-CLOCK", 100, eta=None) + batch = models.Batch( + id=UUID("9ef1794c-617b-4634-82e6-dda1f466ea72"), sku="RETRO-CLOCK", qty=100 + ) line = models.OrderLine( - sku="RETRO-CLOCK", qty=10, id=UUID("591bf188-50e8-4279-904a-bcec100d966b") + id=UUID("591bf188-50e8-4279-904a-bcec100d966b"), sku="RETRO-CLOCK", qty=10 ) # When batch.allocate(line) session.add(batch) + # 'add' method doesn't automatically flushed, so manual flush is required to refer batch await session.flush() # Then @@ -126,7 +131,7 @@ async def test_retrieving_allocations(session: AsyncSession) -> None: ) await session.execute( sa.text( - "INSERT INTO batch (id, sku, purchased_quantity, eta) VALUES " + "INSERT INTO batch (id, sku, qty, eta) VALUES " "('9c5d341f-4876-4a54-81f7-720a390884fb', 'RETRO-CLOCK', 100, null)" ) ) diff --git a/tests/test_repository.py b/tests/integration/test_repository.py similarity index 64% rename from tests/test_repository.py rename to tests/integration/test_repository.py index 49f12ab..e5825f0 100644 --- a/tests/test_repository.py +++ b/tests/integration/test_repository.py @@ -3,7 +3,8 @@ import sqlalchemy as sa from sqlalchemy.ext.asyncio import AsyncSession -from app import models, repository +from app.allocation.adapters import repository +from app.allocation.domain import models async def test_repository_can_save_a_batch_with_allocations(session: AsyncSession) -> None: @@ -11,7 +12,9 @@ async def test_repository_can_save_a_batch_with_allocations(session: AsyncSessio order_line = models.OrderLine( id=UUID("13785f9d-c4e1-4e2b-a789-bc8d2309ce34"), sku="RETRO-CLOCK", qty=10 ) - batch = models.Batch(UUID("b463867c-e573-4a71-bd51-282f32763ee9"), "RETRO-CLOCK", 100, eta=None) + batch = models.Batch( + id=UUID("b463867c-e573-4a71-bd51-282f32763ee9"), sku="RETRO-CLOCK", qty=100 + ) batch.allocate(order_line) repo = repository.PGBatchRepository(session) @@ -19,7 +22,7 @@ async def test_repository_can_save_a_batch_with_allocations(session: AsyncSessio await repo.add(batch) # Then - rows = await session.execute(sa.text("SELECT id, sku, purchased_quantity, eta FROM batch")) + rows = await session.execute(sa.text("SELECT id, sku, qty, eta FROM batch")) assert list(rows) == [(UUID("b463867c-e573-4a71-bd51-282f32763ee9"), "RETRO-CLOCK", 100, None)] rows = await session.execute(sa.text("SELECT order_line_id, batch_id FROM allocation")) assert list(rows) == [ @@ -29,18 +32,16 @@ async def test_repository_can_save_a_batch_with_allocations(session: AsyncSessio async def test_repository_can_retrieve_a_batch_with_allocations(session: AsyncSession) -> None: # Given: create two batches and allocate one to an order - order_line = models.OrderLine( - id=UUID("97c5289f-aa4e-4c5f-8c9b-4b9f7597a7bd"), sku="RETRO-CLOCK", qty=10 - ) - batch1 = models.Batch(UUID("0194c5bc-20af-4fd1-82bf-324e5f26fce7"), "RETRO-CLOCK", 100, None) - batch2 = models.Batch( - UUID("ecd13d3d-c9ee-4a4f-9b1f-c9356ac668ab"), "MINIMALIST-SPOON", 100, None + order_line = models.OrderLine(sku="RETRO-CLOCK", qty=10) + batch1 = models.Batch( + id=UUID("0194c5bc-20af-4fd1-82bf-324e5f26fce7"), sku="RETRO-CLOCK", qty=100 ) + batch2 = models.Batch(sku="MINIMALIST-SPOON", qty=100) session.add(order_line) session.add(batch1) session.add(batch2) batch1.allocate(order_line) - # add is not automatically flushed, so I manual flush to refer to the batch in the next step. + # 'add' method doesn't automatically flushed, so manual flush is required to refer batch await session.flush() # When @@ -50,18 +51,18 @@ async def test_repository_can_retrieve_a_batch_with_allocations(session: AsyncSe # Then assert result.id == UUID("0194c5bc-20af-4fd1-82bf-324e5f26fce7") assert result.sku == "RETRO-CLOCK" - assert result.purchased_quantity == 100 + assert result.qty == 100 assert result.available_quantity == 90 assert list(result.allocations) == [order_line] async def test_repository_can_fetch_batch_list(session: AsyncSession) -> None: # Given: create two batches and allocate one to an order - batch1 = models.Batch(UUID("5518eac3-b214-448a-bd1c-f49bb92c4ed9"), "SKU-1", 100, None) - batch2 = models.Batch(UUID("48ad28ea-d799-4b05-9a18-e3b5b34212a6"), "SKU-2", 100, None) + batch1 = models.Batch(sku="SKU-1", qty=100) + batch2 = models.Batch(sku="SKU-2", qty=100) session.add(batch1) session.add(batch2) - # add is not automatically flushed, so I manual flush to refer to the batch in the next step. + # 'add' method doesn't automatically flushed, so manual flush is required to refer batch await session.flush() # When diff --git a/tests/integration/test_uow.py b/tests/integration/test_uow.py new file mode 100644 index 0000000..45f0c15 --- /dev/null +++ b/tests/integration/test_uow.py @@ -0,0 +1,86 @@ +from typing import Any +from collections.abc import AsyncGenerator +from uuid import UUID, uuid4 + +import pytest +import sqlalchemy as sa +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession +from sqlalchemy.orm import sessionmaker + +from app.allocation.adapters.orm import metadata +from app.allocation.domain import models +from app.allocation.service_layer.unit_of_work import BatchUnitOfWork + + +@pytest.fixture +async def session(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, Any]: + session = sessionmaker(bind=engine, class_=AsyncSession)() + yield session + await session.close() + + +@pytest.fixture(autouse=True) +async def clear_db(session: AsyncSession) -> AsyncGenerator[Any, Any]: + yield session + for table in reversed(metadata.sorted_tables): + await session.execute(table.delete()) + + +async def test_uow_can_retrieve_a_batch_and_allocate_line_to_batch(session: AsyncSession) -> None: + # Given + await session.execute( + sa.text( + "INSERT INTO batch (id, sku, qty, eta) VALUES " + "('9c5d341f-4876-4a54-81f7-720a390884fb', 'RETRO-CLOCK', 100, null)" + ) + ) + await session.commit() + uow = BatchUnitOfWork() + + # When + async with uow: + batch = await uow.repo.get(UUID("9c5d341f-4876-4a54-81f7-720a390884fb")) + line = models.OrderLine( + id=UUID("cd7b01d0-00b5-4237-b1f6-ebb8c50dd7da"), sku="RETRO-CLOCK", qty=10 + ) + batch.allocate(line) + await uow.commit() + + # Then + [[batch_id]] = await session.execute( + sa.text("SELECT batch_id FROM allocation"), + dict( + order_line_id=UUID("cd7b01d0-00b5-4237-b1f6-ebb8c50dd7da"), + batch_id=UUID("9c5d341f-4876-4a54-81f7-720a390884fb"), + ), + ) + assert batch_id == UUID("9c5d341f-4876-4a54-81f7-720a390884fb") + + +async def test_rollback_uncommitted_work_by_default(session: AsyncSession) -> None: + # Given + uow = BatchUnitOfWork() + + # When + async with uow: + batch = models.Batch(id=uuid4(), sku="RETRO-CLOCK", qty=100, eta=None) + await uow.repo.add(batch) + + # Then + result = await session.execute(sa.text("SELECT * FROM batch")) + assert list(result) == [] + + +async def test_rollsback_on_error(session: AsyncSession) -> None: + class MyException(Exception): + pass + + uow = BatchUnitOfWork() + with pytest.raises(MyException): + async with uow: + batch = models.Batch(id=uuid4(), sku="RETRO-CLOCK", qty=100, eta=None) + await uow.repo.add(batch) + raise MyException() + + result = await session.execute(sa.text("SELECT * FROM batch")) + assert list(result) == [] diff --git a/tests/test_allocate.py b/tests/test_allocate.py deleted file mode 100644 index c811ed6..0000000 --- a/tests/test_allocate.py +++ /dev/null @@ -1,59 +0,0 @@ -from datetime import date, timedelta -from uuid import UUID, uuid4 - -import pytest - -from app.models import Batch, OrderLine, OutOfStock, allocate - - -def test_perfers_current_stock_batches_to_allocate() -> None: - # Given - in_stock_batch = Batch(uuid4(), "RETRO-CLOCK", 100, None) - shipment_batch = Batch(uuid4(), "RETRO-CLOCK", 100, date.today() + timedelta(days=1)) - line = OrderLine(id=UUID("5dfffbc2-fca9-49ef-a1c6-bfaba18e0fee"), sku="RETRO-CLOCK", qty=10) - - # When - allocate(line, [in_stock_batch, shipment_batch]) - - # Then - assert in_stock_batch.available_quantity == 90 - assert shipment_batch.available_quantity == 100 - - -def test_prefers_earlier_batches() -> None: - # Given - earliest = Batch(uuid4(), "MINIMALIST-SPOON", 100, eta=date.today()) - medium = Batch(uuid4(), "MINIMALIST-SPOON", 100, eta=date.today() + timedelta(days=1)) - latest = Batch(uuid4(), "MINIMALIST-SPOON", 100, eta=date.today() + timedelta(days=2)) - line = OrderLine(id=uuid4(), sku="MINIMALIST-SPOON", qty=10) - - # When - allocate(line, [medium, earliest, latest]) - - # Then - assert earliest.available_quantity == 90 - assert medium.available_quantity == 100 - assert latest.available_quantity == 100 - - -def test_returns_allocated_batch_ref() -> None: - # Given - in_stock_batch = Batch(uuid4(), "HIGHBROW-POSTER", 100, eta=None) - shipment_batch = Batch(uuid4(), "HIGHBROW-POSTER", 100, eta=date.today() + timedelta(days=1)) - line = OrderLine(id=uuid4(), sku="HIGHBROW-POSTER", qty=10) - - # When - batch_id = allocate(line, [in_stock_batch, shipment_batch]) - - # Then - assert batch_id == in_stock_batch.id - - -def test_raises_out_of_stock_exception_if_cannot_allocate() -> None: - # Given - batch = Batch(uuid4(), "SMALL-FORK", 10, eta=None) - allocate(OrderLine(id=uuid4(), sku="SMALL-FORK", qty=10), [batch]) - - # When & Then - with pytest.raises(OutOfStock, match="SMALL-FORK"): - allocate(OrderLine(id=uuid4(), sku="SMALL-FORK", qty=1), [batch]) diff --git a/tests/test_main.py b/tests/test_main.py index 03518e0..2d7dd55 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -6,7 +6,7 @@ @pytest.fixture() async def client() -> AsyncGenerator[TestClient, None]: - from app.main import app + from app.allocation.routers.main import app yield TestClient(app) diff --git a/tests/test_service.py b/tests/test_service.py deleted file mode 100644 index 388ccaf..0000000 --- a/tests/test_service.py +++ /dev/null @@ -1,44 +0,0 @@ -from uuid import UUID - -import pytest - -from app import models, services -from app.repository import BatchAbstractRepository - - -class FakeRepository(BatchAbstractRepository): - def __init__(self, batches: list[models.Batch]) -> None: - self._batches = set(batches) - - def add(self, batch: models.Batch) -> None: - self._batches.add(batch) - - async def get(self, id: UUID) -> models.Batch: - return next(b for b in self._batches if b.id == id) - - async def list(self) -> list[models.Batch]: - return list(self._batches) - - -async def test_returns_allocation() -> None: - # Given - lint = models.OrderLine(id=UUID("4382a3d5-e3eb-44cd-972b-3a90e793060b"), sku="SKU", qty=10) - batch = models.Batch(UUID("f0e9d78e-ccc7-4f9b-a0e9-4b286b6d8ca5"), "SKU", 100, eta=None) - repo = FakeRepository([batch]) - - # When - result = await services.allocate(lint, repo) - - # Then - assert result == UUID("f0e9d78e-ccc7-4f9b-a0e9-4b286b6d8ca5") - - -async def test_error_for_invalid_sku() -> None: - # Given - lint = models.OrderLine(id=UUID("7970dba1-6d92-47e8-9664-f512493febfc"), sku="UNKNOWN", qty=10) - batch = models.Batch(UUID("1cb92bac-98aa-421c-afe6-7cbba08b050c"), "SKU", 100, eta=None) - repo = FakeRepository([batch]) - - # When & Then - with pytest.raises(services.InvalidSku, match="Invalid sku UNKNOWN"): - await services.allocate(lint, repo) diff --git a/tests/unit/test_allocate.py b/tests/unit/test_allocate.py new file mode 100644 index 0000000..c6f40cc --- /dev/null +++ b/tests/unit/test_allocate.py @@ -0,0 +1,58 @@ +from datetime import date, timedelta + +import pytest + +from app.allocation.domain.models import Batch, OrderLine, OutOfStock, allocate + + +def test_perfers_current_stock_batches_to_allocate() -> None: + # Given + in_stock_batch = Batch(sku="RETRO-CLOCK", qty=100) + shipment_batch = Batch(sku="RETRO-CLOCK", qty=100, eta=date.today() + timedelta(days=1)) + line = OrderLine(sku="RETRO-CLOCK", qty=10) + + # When + allocate(line, [in_stock_batch, shipment_batch]) + + # Then + assert in_stock_batch.available_quantity == 90 + assert shipment_batch.available_quantity == 100 + + +def test_prefers_earlier_batches() -> None: + # Given + earliest = Batch(sku="MINIMALIST-SPOON", qty=100, eta=date.today()) + medium = Batch(sku="MINIMALIST-SPOON", qty=100, eta=date.today() + timedelta(days=1)) + latest = Batch(sku="MINIMALIST-SPOON", qty=100, eta=date.today() + timedelta(days=2)) + line = OrderLine(sku="MINIMALIST-SPOON", qty=10) + + # When + allocate(line, [medium, earliest, latest]) + + # Then + assert earliest.available_quantity == 90 + assert medium.available_quantity == 100 + assert latest.available_quantity == 100 + + +def test_returns_allocated_batch_ref() -> None: + # Given + in_stock_batch = Batch(sku="HIGHBROW-POSTER", qty=100) + shipment_batch = Batch(sku="HIGHBROW-POSTER", qty=100, eta=date.today() + timedelta(days=1)) + line = OrderLine(sku="HIGHBROW-POSTER", qty=10) + + # When + batch_id = allocate(line, [in_stock_batch, shipment_batch]) + + # Then + assert batch_id == in_stock_batch.id + + +def test_raises_out_of_stock_exception_if_cannot_allocate() -> None: + # Given + batch = Batch(sku="SMALL-FORK", qty=10) + allocate(OrderLine(sku="SMALL-FORK", qty=10), [batch]) + + # When & Then + with pytest.raises(OutOfStock, match="SMALL-FORK"): + allocate(OrderLine(sku="SMALL-FORK", qty=1), [batch]) diff --git a/tests/test_batches.py b/tests/unit/test_batches.py similarity index 59% rename from tests/test_batches.py rename to tests/unit/test_batches.py index 8a9042a..5511aca 100644 --- a/tests/test_batches.py +++ b/tests/unit/test_batches.py @@ -1,12 +1,10 @@ -from uuid import uuid4 - -from app.models import Batch, OrderLine +from app.allocation.domain.models import Batch, OrderLine def test_allocating_to_a_batch_reduces_available_quantity() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 10, eta=None) - line = OrderLine(id=uuid4(), sku="SMALL-FORK", qty=10) + batch = Batch(sku="SMALL-FORK", qty=10) + line = OrderLine(sku="SMALL-FORK", qty=10) # When batch.allocate(line) @@ -17,8 +15,8 @@ def test_allocating_to_a_batch_reduces_available_quantity() -> None: def test_can_allocate_if_available_greater_than_required() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 10, eta=None) - line = OrderLine(id=uuid4(), sku="SMALL-FORK", qty=1) + batch = Batch(sku="SMALL-FORK", qty=10) + line = OrderLine(sku="SMALL-FORK", qty=1) # When can_allocate = batch.can_allocate(line) @@ -29,8 +27,8 @@ def test_can_allocate_if_available_greater_than_required() -> None: def test_cannot_allocate_if_available_smaller_than_required() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 1, eta=None) - line = OrderLine(id=uuid4(), sku="SMALL-FORK", qty=10) + batch = Batch(sku="SMALL-FORK", qty=1) + line = OrderLine(sku="SMALL-FORK", qty=10) # When can_allocate = batch.can_allocate(line) @@ -41,8 +39,8 @@ def test_cannot_allocate_if_available_smaller_than_required() -> None: def test_can_allocate_if_available_equal_to_required() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 1, eta=None) - line = OrderLine(id=uuid4(), sku="SMALL-FORK", qty=1) + batch = Batch(sku="SMALL-FORK", qty=1) + line = OrderLine(sku="SMALL-FORK", qty=1) # When can_allocate = batch.can_allocate(line) @@ -53,8 +51,8 @@ def test_can_allocate_if_available_equal_to_required() -> None: def test_cannot_allocate_if_skus_do_not_match() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 10, eta=None) - line = OrderLine(id=uuid4(), sku="LARGE-FORK", qty=10) + batch = Batch(sku="SMALL-FORK", qty=10) + line = OrderLine(sku="LARGE-FORK", qty=10) # When can_allocate = batch.can_allocate(line) @@ -65,8 +63,8 @@ def test_cannot_allocate_if_skus_do_not_match() -> None: def test_allocation_is_idempotent() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 10, eta=None) - line = OrderLine(id=uuid4(), sku="SMALL-FORK", qty=10) + batch = Batch(sku="SMALL-FORK", qty=10) + line = OrderLine(sku="SMALL-FORK", qty=10) # When batch.allocate(line) @@ -81,8 +79,8 @@ def test_allocation_is_idempotent() -> None: def test_deallocate() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 10, eta=None) - line = OrderLine(id=uuid4(), sku="SMALL-FORK", qty=10) + batch = Batch(sku="SMALL-FORK", qty=10) + line = OrderLine(sku="SMALL-FORK", qty=10) # When batch.allocate(line) @@ -94,8 +92,8 @@ def test_deallocate() -> None: def test_can_only_deallocate_allocated_lines() -> None: # Given - batch = Batch(uuid4(), "SMALL-FORK", 10, eta=None) - line = OrderLine(id=uuid4(), sku="SMALL-FORK", qty=10) + batch = Batch(sku="SMALL-FORK", qty=10) + line = OrderLine(sku="SMALL-FORK", qty=10) # When batch.deallocate(line) diff --git a/tests/unit/test_services.py b/tests/unit/test_services.py new file mode 100644 index 0000000..b715438 --- /dev/null +++ b/tests/unit/test_services.py @@ -0,0 +1,91 @@ +from typing import Any +from uuid import UUID, uuid4 + +import pytest + +from app.allocation.adapters.repository import AbstractBatchRepository +from app.allocation.domain import models +from app.allocation.service_layer import services, unit_of_work + + +class FakeRepository(AbstractBatchRepository): + def __init__(self, batches: list[models.Batch]) -> None: + self._batches = set(batches) + + async def add(self, batch: models.Batch) -> None: + self._batches.add(batch) + + async def get(self, id: UUID) -> models.Batch: + return next(b for b in self._batches if b.id == id) + + async def list(self) -> list[models.Batch]: + return list(self._batches) + + +class FakeUnitOfWork(unit_of_work.AbstractUnitOfWork[AbstractBatchRepository]): + def __init__(self) -> None: + self.batches = FakeRepository([]) + self.committed = False + + async def __aexit__(self, *args: Any) -> None: + await self.rollback() + + @property + def repo(self) -> AbstractBatchRepository: + return self.batches + + async def commit(self) -> None: + self.committed = True + + async def rollback(self) -> None: + pass + + +async def test_add_batch() -> None: + # Given + uow = FakeUnitOfWork() + + # When + await services.add_batch( + UUID("3a80e50e-22f4-4907-afb3-8aa37e0c27b8"), "CRUNCHY-ARMCHAIR", 100, None, uow + ) + + # Then + assert await uow.batches.get(UUID("3a80e50e-22f4-4907-afb3-8aa37e0c27b8")) is not None + assert uow.committed + + +async def test_allocate_returns_allocation() -> None: + # Given + uow = FakeUnitOfWork() + await services.add_batch( + UUID("f0e9d78e-ccc7-4f9b-a0e9-4b286b6d8ca5"), "OMINOUS-MIRROR", 100, None, uow + ) + + # When + result = await services.allocate(uuid4(), "OMINOUS-MIRROR", 10, uow) + + # Then + assert result == UUID("f0e9d78e-ccc7-4f9b-a0e9-4b286b6d8ca5") + + +async def test_allocate_error_for_invalid_sku() -> None: + # Given + uow = FakeUnitOfWork() + await services.add_batch(uuid4(), "SKU", 100, None, uow) + + # When & Then + with pytest.raises(services.InvalidSku, match="Invalid sku UNKNOWN"): + await services.allocate(uuid4(), "UNKNOWN", 10, uow) + + +async def test_allocate_commits() -> None: + # Given + uow = FakeUnitOfWork() + await services.add_batch(uuid4(), "OMINOUS-MIRROR", 100, None, uow) + + # When + await services.allocate(uuid4(), "OMINOUS-MIRROR", 10, uow) + + # Then + assert uow.committed