Skip to content

Commit

Permalink
added typehint overloads to accurately infer the return type for ray.… (
Browse files Browse the repository at this point in the history
#45033)

This PR solves issue
[#45032](#45032) by
implementing overloads that handle the return type inference for
synchronous and asynchronous functions and methods decorated with
`@ray.serve.batch`. The change improves compatibility with pylance, mypy
and other type checkers, enhancing user-friendlyness. NOTE: No runtime
changes are expected.

# Before PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve


@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(integers: list[int]) -> list[int]  NOTE: still expecting a list[int] return type

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] NOTE: still expecting a list[int] return type


@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) batch_fn_sync_w_args: G@batch NOTE: G@batch is un-inferred by pylance


@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) batch_fn_async_w_args: G@batch NOTE: G@batch is un-inferred by pylance



class Server:
    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]


e = Server().batch_meth_sync_no_args(1)  # Argument of type "Literal[1]" cannot be assigned to parameter "integers" of type "list[int]" in function "batch_meth_sync_no_args"
# (method) def batch_meth_sync_no_args(integers: list[int]) -> list[int]
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
# (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
g = Server().batch_meth_sync_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_async_w_args() -> R

```
In summary, before the PR, the typechecker (Pylance in the above case)
is unable to match the `@ray.serve.batch` function signature of the
functions and methods that it is intended to decorate. This causes the
input and return type to be inferred incorrectly:
1) The decorated function input type is inferred as as a list, when it
in fact should be a scalar.
2) The decorated function return type is inferred as as a list, when it
in fact should be a scalar.
3) For the decorated method, 0 positional input args are expected
4) For the decorated method, the return type is the unbound TypeVar R.

# After PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve


@ray.serve.batch
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(int) -> int

@ray.serve.batch
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(int) -> Coroutine[Any, Any, int]


@ray.serve.batch(max_batch_size=2)
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) def batch_fn_sync_w_args(int) -> int


@ray.serve.batch
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) def batch_fn_async_w_args(int) -> Coroutine[Any, Any, int]


class Server:
    @ray.serve.batch
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]


e = Server().batch_meth_sync_no_args(1)  # (method) def batch_meth_sync_no_args(int) -> int
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(int) -> Coroutine[Any, Any, int]
g = Server().batch_meth_sync_w_args(1) # (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # (method) def batch_meth_async_w_args(int) -> Coroutine[Any, Any, int]
```

<!-- Thank you for your contribution! Please review
https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?
The return-type for a `@ray.serve.batch` decorated function is not
inferred correctly, making mypy / pylance (probably other checkers too)
lost the function signature for the decorated function/method.
<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number
Closes #45032

## Checks
- [x] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [x] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [x] Release tests

---------

Signed-off-by: Arthur <atte.book@gmail.com>
Signed-off-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
  • Loading branch information
ArthurBook committed May 17, 2024
1 parent e0ac723 commit 3acce1c
Showing 1 changed file with 72 additions and 11 deletions.
83 changes: 72 additions & 11 deletions python/ray/serve/batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,14 @@
Any,
AsyncGenerator,
Callable,
Coroutine,
Dict,
Generic,
Iterable,
List,
Literal,
Optional,
Protocol,
Tuple,
TypeVar,
overload,
Expand Down Expand Up @@ -445,33 +449,90 @@ def _validate_batch_wait_timeout_s(batch_wait_timeout_s):
)


SelfType = TypeVar("SelfType", contravariant=True)
T = TypeVar("T")
R = TypeVar("R")
F = TypeVar("F", Callable[[List[T]], List[R]], Callable[[Any, List[T]], List[R]])
G = TypeVar("G", bound=Callable[[T], R])


# Normal decorator use case (called with no arguments).
@overload
def batch(func: F) -> G:
pass
class _SyncBatchingMethod(Protocol, Generic[SelfType, T, R]):
def __call__(self, self_: SelfType, __batch: List[T], /) -> List[R]:
...


# "Decorator factory" use case (called with arguments).
@overload
class _AsyncBatchingMethod(Protocol, Generic[SelfType, T, R]):
async def __call__(self, self_: SelfType, __batch: List[T], /) -> List[R]:
...


@overload # Sync function for `batch` called WITHOUT arguments
def batch(_sync_func: Callable[[List[T]], List[R]], /) -> Callable[[T], R]:
...


@overload # Async function for `batch` called WITHOUT arguments
def batch(
_async_func: Callable[[List[T]], Coroutine[Any, Any, List[R]]], /
) -> Callable[[T], Coroutine[Any, Any, R]]:
...


@overload # Sync method for `batch` called WITHOUT arguments
def batch(
_sync_meth: _SyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], R]:
...


@overload # Async method for `batch` called WITHOUT arguments
def batch(
_async_meth: _AsyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], Coroutine[Any, Any, R]]:
...


@overload # `batch` called WITH arguments
def batch(
_: Literal[None] = None,
/,
max_batch_size: int = 10,
batch_wait_timeout_s: float = 0.0,
) -> Callable[[F], G]:
pass
) -> "_BatchDecorator":
...


class _BatchDecorator(Protocol):
"""Descibes behaviour of decorator produced by calling `batch` with arguments"""

@overload # Sync function
def __call__(self, _sync_func: Callable[[List[T]], List[R]], /) -> Callable[[T], R]:
...

@overload # Async function
def __call__(
self, _async_func: Callable[[List[T]], Coroutine[Any, Any, List[R]]], /
) -> Callable[[T], Coroutine[Any, Any, R]]:
...

@overload # Sync method
def __call__(
self, _sync_meth: _SyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], R]:
...

@overload # Async method
def __call__(
self, _async_meth: _AsyncBatchingMethod[SelfType, T, R], /
) -> Callable[[SelfType, T], Coroutine[Any, Any, R]]:
...


@PublicAPI(stability="stable")
def batch(
_func: Optional[Callable] = None,
/,
max_batch_size: int = 10,
batch_wait_timeout_s: float = 0.0,
):
) -> Callable:
"""Converts a function to asynchronously handle batches.
The function can be a standalone function or a class method. In both
Expand Down

0 comments on commit 3acce1c

Please sign in to comment.