Skip to content

Commit

Permalink
Add init_memory_sqlite decorator (#1657)
Browse files Browse the repository at this point in the history
* Add init_memory_sqlite function

* Rename and fixing codacy issues

* fixing codacy issue

* fixing codacy issues

* Use mock to test memory sqlite init decorator

* Add AsyncFunc type for init_memory_sqlite

* Support custom models

* Update changelog

* fix ci error
  • Loading branch information
waketzheng committed Jun 24, 2024
1 parent b87f485 commit 8848b20
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 11 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ Added
- DoesNotExist and MultipleObjectsReturned support 'Type[Model]' argument. (#742)(#1650)
- Add argument use_tz and timezone to RegisterTortoise. (#1649)
- Support await `tortoise.contrib.fastapi.RegisterTortoise`. (#1662)
- Add `tortoise.contrib.test.init_memory_sqlite`. (#1657)

Fixed
^^^^^
Expand Down
9 changes: 4 additions & 5 deletions examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
This example demonstrates most basic operations with single model
"""

from tortoise import Tortoise, fields, run_async
from tortoise import fields, run_async
from tortoise.contrib.test import init_memory_sqlite
from tortoise.models import Model


Expand All @@ -18,10 +19,8 @@ def __str__(self):
return self.name


async def run():
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ["__main__"]})
await Tortoise.generate_schemas()

@init_memory_sqlite
async def run() -> None:
event = await Event.create(name="Test")
await Event.filter(id=event.id).update(name="Updated name")

Expand Down
74 changes: 74 additions & 0 deletions tests/contrib/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import subprocess # nosec
from unittest.mock import AsyncMock, patch

from tortoise.contrib import test
from tortoise.contrib.test import init_memory_sqlite


class TestDecorator(test.TestCase):
@test.requireCapability(dialect="sqlite")
async def test_script_with_init_memory_sqlite(self) -> None:
r = subprocess.run(["python", "examples/basic.py"], capture_output=True) # nosec
output = r.stdout.decode()
s = "[{'id': 1, 'name': 'Updated name'}, {'id': 2, 'name': 'Test 2'}]"
self.assertIn(s, output)

@test.requireCapability(dialect="sqlite")
@patch("tortoise.Tortoise.init")
@patch("tortoise.Tortoise.generate_schemas")
async def test_init_memory_sqlite(
self,
mocked_generate: AsyncMock,
mocked_init: AsyncMock,
) -> None:
@init_memory_sqlite
async def run():
return "foo"

res = await run()
self.assertEqual(res, "foo")
mocked_init.assert_awaited_once()
mocked_init.assert_called_once_with(
db_url="sqlite://:memory:", modules={"models": ["__main__"]}
)
mocked_generate.assert_awaited_once()

@test.requireCapability(dialect="sqlite")
@patch("tortoise.Tortoise.init")
@patch("tortoise.Tortoise.generate_schemas")
async def test_init_memory_sqlite_with_models(
self,
mocked_generate: AsyncMock,
mocked_init: AsyncMock,
) -> None:
@init_memory_sqlite(["app.models"])
async def run():
return "foo"

res = await run()
self.assertEqual(res, "foo")
mocked_init.assert_awaited_once()
mocked_init.assert_called_once_with(
db_url="sqlite://:memory:", modules={"models": ["app.models"]}
)
mocked_generate.assert_awaited_once()

@test.requireCapability(dialect="sqlite")
@patch("tortoise.Tortoise.init")
@patch("tortoise.Tortoise.generate_schemas")
async def test_init_memory_sqlite_model_str(
self,
mocked_generate: AsyncMock,
mocked_init: AsyncMock,
) -> None:
@init_memory_sqlite("app.models")
async def run():
return "foo"

res = await run()
self.assertEqual(res, "foo")
mocked_init.assert_awaited_once()
mocked_init.assert_called_once_with(
db_url="sqlite://:memory:", modules={"models": ["app.models"]}
)
mocked_generate.assert_awaited_once()
2 changes: 1 addition & 1 deletion tests/contrib/test_fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@


class TestRegisterTortoise(test.TestCase):
@test.requireCapability(dialect="sqlite") # type:ignore[misc]
@test.requireCapability(dialect="sqlite")
@patch("tortoise.Tortoise.init")
@patch("tortoise.connections.close_all")
async def test_await(
Expand Down
90 changes: 85 additions & 5 deletions tortoise/contrib/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,25 @@
import asyncio
import inspect
import os as _os
import sys
import typing
import unittest
from asyncio.events import AbstractEventLoop
from functools import wraps
from functools import partial, wraps
from types import ModuleType
from typing import Any, Iterable, List, Optional, Union
from typing import Any, Callable, Coroutine, Iterable, List, Optional, TypeVar, Union
from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless

from tortoise import Model, Tortoise, connections
from tortoise.backends.base.config_generator import generate_config as _generate_config
from tortoise.exceptions import DBConnectionError, OperationalError

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec


__all__ = (
"SimpleTestCase",
"TestCase",
Expand All @@ -27,6 +35,7 @@
"skip",
"skipIf",
"skipUnless",
"init_memory_sqlite",
)
_TORTOISE_TEST_DB = "sqlite://:memory:"
# pylint: disable=W0201
Expand Down Expand Up @@ -338,7 +347,7 @@ async def _tearDownDB(self) -> None:
await super()._tearDownDB()


def requireCapability(connection_name: str = "models", **conditions: Any):
def requireCapability(connection_name: str = "models", **conditions: Any) -> Callable:
"""
Skip a test if the required capabilities are not matched.
Expand Down Expand Up @@ -393,9 +402,9 @@ def skip_wrapper(*args, **kwargs):

# Assume a class is decorated
funcs = {
var: getattr(test_item, var)
var: f
for var in dir(test_item)
if var.startswith("test_") and callable(getattr(test_item, var))
if var.startswith("test_") and callable(f := getattr(test_item, var))
}
for name, func in funcs.items():
setattr(
Expand All @@ -407,3 +416,74 @@ def skip_wrapper(*args, **kwargs):
return test_item

return decorator


T = TypeVar("T")
P = ParamSpec("P")
AsyncFunc = Callable[P, Coroutine[None, None, T]]
AsyncFuncDeco = Callable[..., AsyncFunc]
ModulesConfigType = Union[str, List[str]]


@typing.overload
def init_memory_sqlite(models: Union[ModulesConfigType, None] = None) -> AsyncFuncDeco: ...


@typing.overload
def init_memory_sqlite(models: AsyncFunc) -> AsyncFunc: ...


def init_memory_sqlite(
models: Union[ModulesConfigType, AsyncFunc, None] = None
) -> Union[AsyncFunc, AsyncFuncDeco]:
"""
For single file style to run code with memory sqlite
:param models: list_of_modules that should be discovered for models, default to ['__main__'].
Usage:
.. code-block:: python3
from tortoise import fields, models, run_async
from tortoise.contrib.test import init_memory_sqlite
class MyModel(models.Model):
id = fields.IntField(primary_key=True)
name = fields.TextField()
@init_memory_sqlite
async def run():
obj = await MyModel.create(name='')
assert obj.id == 1
if __name__ == '__main__'
run_async(run)
Custom models example:
.. code-block:: python3
@init_memory_sqlite(models=['app.models', 'aerich.models'])
async def run():
...
"""

def wrapper(func: AsyncFunc, ms: List[str]):
@wraps(func)
async def runner(*args, **kwargs) -> T:
await Tortoise.init(db_url="sqlite://:memory:", modules={"models": ms})
await Tortoise.generate_schemas()
return await func(*args, **kwargs)

return runner

default_models = ["__main__"]
if inspect.iscoroutinefunction(models):
return wrapper(models, default_models)
if models is None:
models = default_models
elif isinstance(models, str):
models = [models]
return partial(wrapper, ms=models)

0 comments on commit 8848b20

Please sign in to comment.