Skip to content

Commit

Permalink
Type annotations & tests (#67)
Browse files Browse the repository at this point in the history
* Add more typing

* A few more tests

* More typing

* Tests for init at 100% coverage

* More model compilation tests

* Update deps

* Added some Q tests

* Silence codacy
  • Loading branch information
grigi authored and abondar committed Nov 23, 2018
1 parent fc91386 commit 98c37cd
Show file tree
Hide file tree
Showing 28 changed files with 492 additions and 213 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
checkfiles = tortoise/ examples/ setup.py conftest.py
mypy_flags = --warn-unused-configs --warn-redundant-casts --ignore-missing-imports --allow-untyped-decorators
mypy_flags = --warn-unused-configs --warn-redundant-casts --ignore-missing-imports --allow-untyped-decorators --no-implicit-optional

help:
@echo "Tortoise ORM development makefile"
Expand Down
7 changes: 2 additions & 5 deletions docs/roadmap.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ For ``v1.0`` that involves:
Mid-term
========

Here we have all the features that is slightly further our:
Here we have all the features that is slightly further out, in no particular order:

* Performance work:
* Sub queries
* Bulk operations
* ...
* Consider using Cython to accelerate critical loops

* Convenience/Ease-Of-Use work:
* Make ``DELETE`` honour ``limit`` and ``offset``
Expand All @@ -39,9 +39,6 @@ Here we have all the features that is slightly further our:
* Make it easier to do simple aggregations
* Expand annotation framework to add statistical functions

* Transaction framework
* Ability to set ACID conformance expectations

* Migrations
* Comprehensive schema in Migrations
* Automatic forward Migration building
Expand Down
18 changes: 9 additions & 9 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ aiosqlite==0.8.0
alabaster==0.7.12 # via sphinx
asn1crypto==0.24.0 # via cryptography
astroid==2.0.4 # via pylint
asyncpg==0.18.1
asyncpg==0.18.2
asynctest==0.12.2
atomicwrites==1.2.1 # via pytest
attrs==18.2.0 # via pytest
Expand All @@ -24,9 +24,9 @@ ciso8601==2.1.1
click==7.0 # via pip-tools
cloud-sptheme==1.9.4
colorama==0.4.0 # via green
coverage==4.5.1 # via coveralls, green, nose2
coverage==4.5.2 # via coveralls, green, nose2
coveralls==1.5.1
cryptography==2.3.1 # via pymysql
cryptography==2.4.1 # via pymysql
docopt==0.6.2 # via coveralls
docutils==0.14
filelock==3.0.10 # via tox
Expand Down Expand Up @@ -58,24 +58,24 @@ pygments==2.2.0
pylint==2.1.1
pymysql==0.9.2 # via aiomysql
pyparsing==2.3.0 # via packaging
pypika==0.16.1
pytest==3.10.0
pypika==0.18.3
pytest==4.0.0
pytz==2018.7 # via babel
pyyaml==3.13 # via bandit
pyyaml==3.13
requests==2.20.1 # via coveralls, sphinx
six==1.11.0 # via astroid, bandit, cryptography, more-itertools, nose2, packaging, pip-tools, pytest, sphinx, stevedore, tox
smmap2==2.0.5 # via gitdb2
snowballstemmer==1.2.1 # via sphinx
sphinx-autodoc-typehints==1.3.0
sphinx==1.8.1
sphinx-autodoc-typehints==1.4.0
sphinx==1.8.2
sphinxcontrib-websupport==1.1.0 # via sphinx
stevedore==1.30.0 # via bandit
termstyle==0.1.11 # via green
testfixtures==6.3.0 # via flake8-isort
toml==0.10.0 # via tox
tox==3.5.3
typed-ast==1.1.0 # via astroid, mypy
unidecode==1.0.22 # via green
unidecode==1.0.23 # via green
urllib3==1.24.1 # via requests
virtualenv==16.1.0 # via tox
wrapt==1.10.11 # via astroid
Expand Down
8 changes: 4 additions & 4 deletions requirements-pypy.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@ asynctest==0.12.2
cffi==1.11.5 # via cryptography
ciso8601==2.1.1
colorama==0.4.0 # via green
coverage==4.5.1 # via green
cryptography==2.3.1 # via pymysql
coverage==4.5.2 # via green
cryptography==2.4.1 # via pymysql
green==2.13.0
idna==2.7 # via cryptography
pycparser==2.19 # via cffi
pymysql==0.9.2 # via aiomysql
pypika==0.16.1
pypika==0.18.3
pyyaml==3.13
six==1.11.0 # via cryptography
termstyle==0.1.11 # via green
unidecode==1.0.22 # via green
unidecode==1.0.23 # via green
62 changes: 32 additions & 30 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import os
from copy import deepcopy
from inspect import isclass
from typing import Dict, List, Optional, Type # noqa
from typing import Coroutine, Dict, List, Optional, Type, Union, cast # noqa

from tortoise import fields
from tortoise.backends.base.client import BaseDBAsyncClient
from tortoise.backends.base.config_generator import expand_db_url, generate_config
from tortoise.exceptions import ConfigurationError # noqa
from tortoise.fields import ManyToManyRelationManager # noqa
Expand All @@ -22,16 +23,16 @@


class Tortoise:
apps = {} # type: dict
_connections = {} # type: dict
_inited = False
apps = {} # type: Dict[str, Dict[str, Type[Model]]]
_connections = {} # type: Dict[str, BaseDBAsyncClient]
_inited = False # type: bool

@classmethod
def get_connection(cls, connection_name):
def get_connection(cls, connection_name: str) -> BaseDBAsyncClient:
return cls._connections[connection_name]

@classmethod
def _init_relations(cls):
def _init_relations(cls) -> None:
for app_name, app in cls.apps.items():
for model_name, model in app.items():
if model._meta._inited:
Expand All @@ -42,7 +43,7 @@ def _init_relations(cls):
model._meta.table = model.__name__.lower()

for field in model._meta.fk_fields:
field_object = model._meta.fields_map[field]
field_object = cast(fields.ForeignKeyField, model._meta.fields_map[field])
reference = field_object.model_name
related_app_name, related_model_name = reference.split('.')
related_model = cls.apps[related_app_name][related_model_name]
Expand All @@ -66,45 +67,46 @@ def _init_relations(cls):
related_model._meta.fields.add(backward_relation_name)

for field in model._meta.m2m_fields:
field_object = model._meta.fields_map[field]
if field_object._generated:
field_mobject = cast(fields.ManyToManyField, model._meta.fields_map[field])
if field_mobject._generated:
continue

backward_key = field_object.backward_key
backward_key = field_mobject.backward_key
if not backward_key:
backward_key = '{}_id'.format(model._meta.table)
field_object.backward_key = backward_key
field_mobject.backward_key = backward_key

reference = field_object.model_name
reference = field_mobject.model_name
related_app_name, related_model_name = reference.split('.')
related_model = cls.apps[related_app_name][related_model_name]

field_object.type = related_model
field_mobject.type = related_model

backward_relation_name = field_object.related_name
backward_relation_name = field_mobject.related_name
if not backward_relation_name:
backward_relation_name = '{}s'.format(model._meta.table)
backward_relation_name = field_mobject.related_name = \
'{}_through'.format(model._meta.table)
if backward_relation_name in related_model._meta.fields:
raise ConfigurationError(
'backward relation "{}" duplicates in model {}'.format(
backward_relation_name, related_model_name))

if not field_object.through:
if not field_mobject.through:
related_model_table_name = (
related_model._meta.table
if related_model._meta.table else related_model.__name__.lower()
)

field_object.through = '{}_{}'.format(
field_mobject.through = '{}_{}'.format(
model._meta.table,
related_model_table_name,
)

m2m_relation = fields.ManyToManyField(
'{}.{}'.format(app_name, model_name),
field_object.through,
forward_key=field_object.backward_key,
backward_key=field_object.forward_key,
field_mobject.through,
forward_key=field_mobject.backward_key,
backward_key=field_mobject.forward_key,
related_name=field,
type=model
)
Expand All @@ -114,7 +116,7 @@ def _init_relations(cls):
backward_relation_name,
m2m_relation,
)
model._meta.filters.update(get_m2m_filters(field, field_object))
model._meta.filters.update(get_m2m_filters(field, field_mobject))
related_model._meta.filters.update(
get_m2m_filters(backward_relation_name, m2m_relation)
)
Expand All @@ -124,7 +126,7 @@ def _init_relations(cls):
related_model._meta.fields.add(backward_relation_name)

@classmethod
def _discover_client_class(cls, engine):
def _discover_client_class(cls, engine: str) -> BaseDBAsyncClient:
# Let exception bubble up for transparency
engine_module = importlib.import_module(engine)

Expand Down Expand Up @@ -153,22 +155,22 @@ def _discover_models(cls, models_path, app_label) -> List[Type[Model]]:
return discovered_models

@classmethod
async def _init_connections(cls, connections_config, create_db):
async def _init_connections(cls, connections_config: dict, create_db: bool) -> None:
for name, info in connections_config.items():
if isinstance(info, str):
info = expand_db_url(info)
client_class = cls._discover_client_class(info.get('engine'))
db_params = deepcopy(info['credentials'])
db_params.update({'connection_name': name})
connection = client_class(**db_params)
connection = client_class(**db_params) # type: ignore
if create_db:
await connection.db_create()
await connection.create_connection(with_db=True)
cls._connections[name] = connection
current_transaction_map[name] = ContextVar(name, default=None)

@classmethod
def _init_apps(cls, apps_config):
def _init_apps(cls, apps_config: dict) -> None:
for name, info in apps_config.items():
try:
cls.get_connection(info.get('default_connection', 'default'))
Expand All @@ -193,7 +195,7 @@ def _init_apps(cls, apps_config):
cls._build_initial_querysets()

@classmethod
def _get_config_from_config_file(cls, config_file):
def _get_config_from_config_file(cls, config_file: str) -> dict:
_, extension = os.path.splitext(config_file)
if extension in ('.yml', '.yaml'):
import yaml
Expand All @@ -211,7 +213,7 @@ def _get_config_from_config_file(cls, config_file):
return config

@classmethod
def _build_initial_querysets(cls):
def _build_initial_querysets(cls) -> None:
for app in cls.apps.values():
for model in app.values():
model._meta.generate_filters()
Expand Down Expand Up @@ -318,13 +320,13 @@ async def init(
cls._inited = True

@classmethod
async def close_connections(cls):
async def close_connections(cls) -> None:
for connection in cls._connections.values():
await connection.close()
cls._connections = {}

@classmethod
async def _reset_apps(cls):
async def _reset_apps(cls) -> None:
for app in cls.apps.values():
for model in app.values():
model._meta.default_connection = None
Expand Down Expand Up @@ -358,7 +360,7 @@ async def _drop_databases(cls) -> None:
await cls._reset_apps()


def run_async(coro):
def run_async(coro: Coroutine) -> None:
"""
Simple async runner that cleans up DB connections on exit.
This is meant for simple scripts.
Expand Down
2 changes: 1 addition & 1 deletion tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ async def create_connection(self, with_db: bool) -> None:
))

async def close(self) -> None:
if self._connection:
if self._connection: # pragma: nobranch
await self._connection.close()
self.log.debug(
'Closed connection %s with params: user=%s database=%s host=%s port=%s',
Expand Down
4 changes: 2 additions & 2 deletions tortoise/backends/asyncpg/schema_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@


class AsyncpgSchemaGenerator(BaseSchemaGenerator):
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.FIELD_TYPE_MAP.update({fields.JSONField: 'JSONB'})

def _get_primary_key_create_string(self, field_name):
def _get_primary_key_create_string(self, field_name: str) -> str:
return '"{}" SERIAL NOT NULL PRIMARY KEY'.format(field_name)
6 changes: 3 additions & 3 deletions tortoise/backends/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ async def execute_script(self, query: str) -> None:
class ConnectionWrapper:
__slots__ = ('connection', )

def __init__(self, connection):
def __init__(self, connection) -> None:
self.connection = connection

async def __aenter__(self):
return self.connection

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
pass


Expand All @@ -71,7 +71,7 @@ async def __aenter__(self):
await self.start()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
if exc_type:
await self.rollback()
else:
Expand Down
Loading

0 comments on commit 98c37cd

Please sign in to comment.