Skip to content

Commit

Permalink
Python >= 3.6, and various post-review changes
Browse files Browse the repository at this point in the history
  • Loading branch information
aawilson committed Nov 9, 2022
1 parent 1060708 commit e6d2db1
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 37 deletions.
22 changes: 0 additions & 22 deletions .travis.yml

This file was deleted.

34 changes: 24 additions & 10 deletions aiodataloader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

from asyncio import AbstractEventLoop, Future
from asyncio import gather, ensure_future, get_event_loop, iscoroutine, iscoroutinefunction
from collections import namedtuple
Expand All @@ -16,26 +18,38 @@
TypeVar,
Union,
)
from typing_extensions import Protocol, TypeGuard

__version__ = '0.2.1'
if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol

if sys.version_info >= (3, 10):
from typing import TypeGuard
else:
from typing_extensions import TypeGuard

def iscoroutinefunctionorpartial(fn: Callable) -> TypeGuard[Callable[..., Coroutine]]:
return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)
__version__ = '0.2.1'


KeyT = TypeVar("KeyT")
ReturnT = TypeVar("ReturnT")
CacheKeyT = TypeVar("CacheKeyT")
DataLoaderT = TypeVar("DataLoaderT", bound="DataLoader")
DataLoaderT = TypeVar("DataLoaderT", bound="DataLoader[Any, Any]")
T = TypeVar("T")


def iscoroutinefunctionorpartial(
fn: Union[Callable[..., ReturnT], Coroutine[Any, Any, ReturnT]],
) -> TypeGuard[Callable[..., Coroutine[Any, Any, ReturnT]]]:
return iscoroutinefunction(fn.func if isinstance(fn, partial) else fn)


class BatchLoadFnProto(Protocol[KeyT, ReturnT]):
    """Structural type for a batch-load callable: an async callable that takes a
    list of keys and returns the corresponding list of loaded values."""
    async def __call__(self, keys: List[KeyT]) -> List[ReturnT]:
        ...


# One queued load request: the requested key paired with the Future that will
# receive its loaded value (or the batch error).
Loader = namedtuple('Loader', 'key,future')


Expand All @@ -52,7 +66,7 @@ def __init__(
max_batch_size: Optional[int] = None,
cache: Optional[bool] = None,
get_cache_key: Optional[Callable[[KeyT], Union[CacheKeyT, KeyT]]] = None,
cache_map: Optional[Dict[Union[CacheKeyT, KeyT], Any]] = None,
cache_map: Optional[Dict[Union[CacheKeyT, KeyT], "Future[ReturnT]"]] = None,
loop: Optional[AbstractEventLoop] = None,
):
self.loop = loop or get_event_loop()
Expand Down Expand Up @@ -188,7 +202,7 @@ def prime(self: DataLoaderT, key: KeyT, value: ReturnT) -> DataLoaderT:
return self


def enqueue_post_future_job(loop: AbstractEventLoop, loader: DataLoader) -> None:
def enqueue_post_future_job(loop: AbstractEventLoop, loader: DataLoader[Any, Any]) -> None:
async def dispatch() -> None:
dispatch_queue(loader)

def get_chunks(iterable_obj: List[T], chunk_size: int = 1) -> Iterator[List[T]]:
    """Break iterable_obj into consecutive slices, each holding at most chunk_size items."""
    offsets = range(0, len(iterable_obj), chunk_size)
    return (iterable_obj[offset:offset + chunk_size] for offset in offsets)


def dispatch_queue(loader: DataLoader) -> None:
def dispatch_queue(loader: DataLoader[Any, Any]) -> None:
"""
Given the current state of a Loader instance, perform a batch load
from its current queue.
Expand All @@ -224,7 +238,7 @@ def dispatch_queue(loader: DataLoader) -> None:
ensure_future(dispatch_queue_batch(loader, queue))


async def dispatch_queue_batch(loader: DataLoader, queue: List[Loader]) -> None:
async def dispatch_queue_batch(loader: DataLoader[Any, Any], queue: List[Loader]) -> None:
# Collect all keys to be loaded in this dispatch
keys = [ql.key for ql in queue]

Expand Down Expand Up @@ -275,7 +289,7 @@ async def dispatch_queue_batch(loader: DataLoader, queue: List[Loader]) -> None:
return failed_dispatch(loader, queue, e)


def failed_dispatch(loader: DataLoader, queue: List[Loader], error: Exception) -> None:
def failed_dispatch(loader: DataLoader[Any, Any], queue: List[Loader], error: Exception) -> None:
"""
Do not cache individual loads if the entire batch dispatch fails,
but still reject each request so they do not hang.
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ max-line-length = 120
files = aiodataloader.py, test_aiodataloader.py
follow_imports = silent
ignore_missing_imports = True
strict = True
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def get_version(filename):
'Intended Audience :: Developers',
'Topic :: Software Development :: Libraries',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Programming Language :: Python :: 3.8',
Expand All @@ -48,6 +47,7 @@ def get_version(filename):
],
keywords='concurrent future deferred aiodataloader',
py_modules=['aiodataloader'],
python_requires='>=3.6',
extras_require={
'lint': ['flake8', 'mypy'],
'test': tests_require,
Expand Down
7 changes: 3 additions & 4 deletions test_aiodataloader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from collections.abc import Callable, Coroutine
import pytest
from asyncio import gather
from functools import partial
from pytest import raises
from typing import Dict, List, Optional, Tuple, TypeVar
from typing import Any, Callable, Coroutine, Dict, List, Optional, Tuple, TypeVar
from aiodataloader import DataLoader

pytestmark = pytest.mark.asyncio
Expand All @@ -19,7 +18,7 @@ async def do_test():

def id_loader(
*, resolve: Optional[Callable[..., Coroutine]] = None, **dl_kwargs
) -> Tuple[DataLoader, List]:
) -> Tuple[DataLoader[Any, Any], List]:
load_calls = []

async def default_resolve(x: T1) -> T1:
Expand All @@ -29,7 +28,7 @@ async def fn(keys: List) -> List:
load_calls.append(keys)
return await (resolve or default_resolve)(keys)

identity_loader: DataLoader = DataLoader(fn, **dl_kwargs)
identity_loader: DataLoader[Any, Any] = DataLoader(fn, **dl_kwargs)
return identity_loader, load_calls


Expand Down

0 comments on commit e6d2db1

Please sign in to comment.