Skip to content

Commit

Permalink
Use ParamSpec to preserve function signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
henribru committed May 8, 2022
1 parent 0232995 commit 49eec7e
Show file tree
Hide file tree
Showing 11 changed files with 262 additions and 179 deletions.
4 changes: 3 additions & 1 deletion celery-stubs/_state.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

from celery.app.base import Celery
from celery.app.task import Task

current_app: Celery
current_task: Task
current_task: Task[Any, Any]
46 changes: 42 additions & 4 deletions celery-stubs/app/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,55 @@ from celery.app import events as events
from celery.app import task as task
from celery.app.task import Task as Task
from celery.utils.threads import _LocalStack
from typing_extensions import Concatenate, Literal, ParamSpec

_T = TypeVar("_T", bound=Task)
_T = TypeVar("_T", bound=Task[Any, Any])
_P = ParamSpec("_P")
_R = TypeVar("_R")

@overload
def shared_task(fun: Callable[..., Any]) -> Task: ...
def shared_task(fun: Callable[_P, _R]) -> Task[_P, _R]: ...
@overload
def shared_task(
*,
name: str = ...,
serializer: str = ...,
bind: bool = ...,
bind: Literal[False] = ...,
autoretry_for: Tuple[Type[Exception], ...] = ...,
max_retries: int = ...,
default_retry_delay: int = ...,
acks_late: bool = ...,
ignore_result: bool = ...,
soft_time_limit: int = ...,
time_limit: int = ...,
base: None = ...,
retry_kwargs: Dict[str, Any] = ...,
retry_backoff: Union[bool, int] = ...,
retry_backoff_max: int = ...,
retry_jitter: bool = ...,
typing: bool = ...,
rate_limit: Optional[str] = ...,
trail: bool = ...,
send_events: bool = ...,
store_errors_even_if_ignored: bool = ...,
autoregister: bool = ...,
track_started: bool = ...,
acks_on_failure_or_timeout: bool = ...,
reject_on_worker_lost: bool = ...,
throws: Tuple[Type[Exception], ...] = ...,
expires: Optional[Union[float, datetime]] = ...,
priority: Optional[int] = ...,
resultrepr_maxsize: int = ...,
request_stack: _LocalStack = ...,
abstract: bool = ...,
queue: str = ...,
) -> Callable[[Callable[_P, _R]], Task[_P, _R]]: ...
@overload
def shared_task(
*,
name: str = ...,
serializer: str = ...,
bind: Literal[True],
autoretry_for: Tuple[Type[Exception], ...] = ...,
max_retries: int = ...,
default_retry_delay: int = ...,
Expand Down Expand Up @@ -46,7 +84,7 @@ def shared_task(
request_stack: _LocalStack = ...,
abstract: bool = ...,
queue: str = ...,
) -> Callable[[Callable[..., Any]], Task]: ...
) -> Callable[[Callable[Concatenate[Task[_P, _R], _P], _R]], Task[_P, _R]]: ...
@overload
def shared_task(
*,
Expand Down
81 changes: 61 additions & 20 deletions celery-stubs/app/base.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,11 @@ from celery.utils.dispatch import Signal
from celery.utils.objects import FallbackContext
from celery.utils.threads import _LocalStack
from celery.worker import WorkController as CeleryWorkController
from typing_extensions import Concatenate, Literal, ParamSpec

_T = TypeVar("_T", bound=CeleryTask)
_T = TypeVar("_T", bound=CeleryTask[Any, Any])
_P = ParamSpec("_P")
_R = TypeVar("_R")

class Celery:
on_configure: Signal
Expand All @@ -59,16 +62,16 @@ class Celery:
changes: Optional[Dict[str, Any]] = ...,
config_source: Optional[Union[str, object]] = ...,
fixups: Optional[List[str]] = ...,
task_cls: Optional[Union[str, Type[CeleryTask]]] = ...,
task_cls: Optional[Union[str, Type[CeleryTask[Any, Any]]]] = ...,
autofinalize: bool = ...,
namespace: Optional[str] = ...,
strict_typing: bool = ...,
) -> None: ...
def _task_from_fun(
self,
fun: Callable[..., Any],
fun: Callable[_P, _R],
name: Optional[str] = ...,
base: Optional[CeleryTask] = ...,
base: Optional[Type[CeleryTask[Any, Any]]] = ...,
bind: bool = ...,
# options
autoretry_for: Tuple[Type[Exception], ...] = ...,
Expand Down Expand Up @@ -99,15 +102,15 @@ class Celery:
resultrepr_maxsize: int = ...,
request_stack: _LocalStack = ...,
abstract: bool = ...,
) -> CeleryTask: ...
) -> CeleryTask[_P, _R]: ...
def on_init(self) -> None: ...
def set_current(self) -> None: ...
def set_default(self) -> None: ...
def close(self) -> None: ...
def start(self, argv: Optional[List[str]] = ...) -> NoReturn: ...
def worker_main(self, argv: Optional[List[str]] = ...) -> NoReturn: ...
@overload
def task(self, fun: Callable[..., Any]) -> CeleryTask: ...
def task(self, fun: Callable[_P, _R]) -> CeleryTask[_P, _R]: ...
@overload
def task(
self,
Expand Down Expand Up @@ -150,7 +153,43 @@ class Celery:
*,
name: str = ...,
serializer: str = ...,
bind: bool = ...,
bind: Literal[False] = ...,
autoretry_for: Tuple[Type[Exception], ...] = ...,
max_retries: int = ...,
default_retry_delay: int = ...,
acks_late: bool = ...,
ignore_result: bool = ...,
soft_time_limit: int = ...,
time_limit: int = ...,
base: None = ...,
retry_kwargs: Dict[str, Any] = ...,
retry_backoff: Union[bool, int] = ...,
retry_backoff_max: int = ...,
retry_jitter: bool = ...,
typing: bool = ...,
rate_limit: Optional[str] = ...,
trail: bool = ...,
send_events: bool = ...,
store_errors_even_if_ignored: bool = ...,
autoregister: bool = ...,
track_started: bool = ...,
acks_on_failure_or_timeout: bool = ...,
reject_on_worker_lost: bool = ...,
throws: Tuple[Type[Exception], ...] = ...,
expires: Optional[Union[float, datetime.datetime]] = ...,
priority: Optional[int] = ...,
resultrepr_maxsize: int = ...,
request_stack: _LocalStack = ...,
abstract: bool = ...,
queue: str = ...,
) -> Callable[[Callable[_P, _R]], CeleryTask[_P, _R]]: ...
@overload
def task(
self,
*,
name: str = ...,
serializer: str = ...,
bind: Literal[True],
autoretry_for: Tuple[Type[Exception], ...] = ...,
max_retries: int = ...,
default_retry_delay: int = ...,
Expand Down Expand Up @@ -179,12 +218,14 @@ class Celery:
request_stack: _LocalStack = ...,
abstract: bool = ...,
queue: str = ...,
) -> Callable[[Callable[..., Any]], CeleryTask]: ...
) -> Callable[
[Callable[Concatenate[CeleryTask[_P, _R], _P], _R]], CeleryTask[_P, _R]
]: ...
def register_task(
self,
task: Union[CeleryTask, Type[CeleryTask]],
task: Union[CeleryTask[Any, Any], Type[CeleryTask[Any, Any]]],
**options: Any,
) -> CeleryTask: ...
) -> CeleryTask[Any, Any]: ...
def gen_task_name(self, name: str, module: object) -> str: ...
def finalize(self, auto: bool = ...) -> None: ...
def add_defaults(self, fun: Callable[[], Dict[str, Any]]) -> None: ...
Expand Down Expand Up @@ -225,11 +266,11 @@ class Celery:
producer: Optional[kombu.Producer] = ...,
connection: Optional[kombu.Connection] = ...,
router: Optional[Router] = ...,
result_cls: Optional[Type[celery.result.AsyncResult]] = ...,
result_cls: Optional[Type[celery.result.AsyncResult[Any]]] = ...,
expires: Optional[Union[float, datetime.datetime]] = ...,
publisher: Optional[kombu.Producer] = ...,
link: Optional[Signature] = ...,
link_error: Optional[Signature] = ...,
link: Optional[Signature[Any]] = ...,
link_error: Optional[Signature[Any]] = ...,
add_to_parent: bool = ...,
group_id: Optional[str] = ...,
retries: int = ...,
Expand All @@ -246,7 +287,7 @@ class Celery:
# options
ignore_result: bool = ...,
**options: Any,
) -> celery.result.AsyncResult: ...
) -> celery.result.AsyncResult[Any]: ...
def connection_for_read(
self, url: Optional[str] = ..., **kwargs: Any
) -> kombu.Connection: ...
Expand Down Expand Up @@ -283,11 +324,11 @@ class Celery:
def select_queues(self, queues: Optional[Sequence[str]] = ...) -> None: ...
def either(self, default_key: str, *defaults: Any) -> Any: ...
def bugreport(self) -> str: ...
def signature(self, *args: Any, **kwargs: Any) -> Signature: ...
def signature(self, *args: Any, **kwargs: Any) -> Signature[Any]: ...
def add_periodic_task(
self,
schedule: BaseSchedule | float,
sig: Signature,
sig: Signature[Any],
args: Tuple[Any, ...] = ...,
kwargs: Dict[str, Any] = ...,
name: Optional[str] = ...,
Expand All @@ -302,21 +343,21 @@ class Celery:
@property
def Beat(self) -> CeleryBeat: ...
@property
def Task(self) -> CeleryTask: ...
def Task(self) -> CeleryTask[Any, Any]: ...
@property
def annotations(self) -> List[Dict[str, Any]]: ...
@property
def AsyncResult(self) -> celery.result.AsyncResult: ...
def AsyncResult(self) -> celery.result.AsyncResult[Any]: ...
@property
def ResultSet(self) -> celery.result.ResultSet: ...
@property
def GroupResult(self) -> celery.result.GroupResult: ...
@property
def pool(self) -> kombu.pools.ProducerPool: ...
@property
def current_task(self) -> Optional[CeleryTask]: ...
def current_task(self) -> Optional[CeleryTask[Any, Any]]: ...
@property
def current_worker_task(self) -> Optional[CeleryTask]: ...
def current_worker_task(self) -> Optional[CeleryTask[Any, Any]]: ...
@property
def oid(self) -> str: ...
@property
Expand Down

0 comments on commit 49eec7e

Please sign in to comment.