Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added typehint overloads to accurately infer the return type for ray.… (
ray-project#45033) This PR solves issue [ray-project#45032](ray-project#45032) by implementing overloads that handle the return type inference for synchronous and asynchronous functions and methods decorated with `@ray.serve.batch`. The change improves compatibility with pylance, mypy and other type checkers, enhancing user-friendlyness. NOTE: No runtime changes are expected. # Before PR: ```python '''Test the type infernece for return value of a batched function or method with pylance.''' import ray.serve @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue def batch_fn_sync_no_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(integers: list[int]) -> list[int] NOTE: still expecting a list[int] return type @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue async def batch_fn_async_no_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] NOTE: still expecting a list[int] return type @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue def batch_fn_sync_w_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] c = batch_fn_sync_w_args(1) # (function) batch_fn_sync_w_args: G@batch NOTE: G@batch is un-inferred by pylance @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue async def batch_fn_async_w_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] d = batch_fn_async_w_args(1) # (function) batch_fn_async_w_args: G@batch NOTE: G@batch is un-inferred by pylance class Server: @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] e = Server().batch_meth_sync_no_args(1) # Argument of type "Literal[1]" cannot be assigned to parameter "integers" of type "list[int]" in function "batch_meth_sync_no_args" # (method) def batch_meth_sync_no_args(integers: list[int]) -> list[int] f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] g = Server().batch_meth_sync_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue # (method) def batch_meth_sync_w_args(int) -> int h = Server().batch_meth_async_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue # (method) def batch_meth_async_w_args() -> R ``` In summary, before the PR, the typechecker (Pylance in the above case) is unable to match the `@ray.serve.batch` function signature of the functions and methods that it is intended to decorate. This causes the input and return type to be inferred incorrectly: 1) The decorated function input type is inferred as as a list, when it in fact should be a scalar. 2) The decorated function return type is inferred as as a list, when it in fact should be a scalar. 3) For the decorated method, 0 positional input args are expected 4) For the decorated method, the return type is the unbound TypeVar R. # After PR: ```python '''Test the type infernece for return value of a batched function or method with pylance.''' import ray.serve @ray.serve.batch def batch_fn_sync_no_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(int) -> int @ray.serve.batch async def batch_fn_async_no_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(int) -> Coroutine[Any, Any, int] @ray.serve.batch(max_batch_size=2) def batch_fn_sync_w_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] c = batch_fn_sync_w_args(1) # (function) def batch_fn_sync_w_args(int) -> int @ray.serve.batch async def batch_fn_async_w_args(integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] d = batch_fn_async_w_args(1) # (function) def batch_fn_async_w_args(int) -> Coroutine[Any, Any, int] class Server: @ray.serve.batch def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] @ray.serve.batch async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] @ray.serve.batch(max_batch_size=2) def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] @ray.serve.batch(max_batch_size=2) async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]: "docs" return [i * 2 for i in integers] e = Server().batch_meth_sync_no_args(1) # (method) def batch_meth_sync_no_args(int) -> int f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(int) -> Coroutine[Any, Any, int] g = Server().batch_meth_sync_w_args(1) # (method) def batch_meth_sync_w_args(int) -> int h = Server().batch_meth_async_w_args(1) # (method) def batch_meth_async_w_args(int) -> Coroutine[Any, Any, int] ``` <!-- Thank you for your contribution! Please review https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before opening a pull request. --> <!-- Please add a reviewer to the assignee section when you create a PR. If you don't have the access to it, we will shortly find a reviewer and assign them to your PR. --> ## Why are these changes needed? The return-type for a `@ray.serve.batch` decorated function is not inferred correctly, making mypy / pylance (probably other checkers too) lost the function signature for the decorated function/method. <!-- Please give a short summary of the change and the problem this solves. --> ## Related issue number Closes ray-project#45032 ## Checks - [x] I've signed off every commit(by using the -s flag, i.e., `git commit -s`) in this PR. - [x] I've run `scripts/format.sh` to lint the changes in this PR. - [x] I've included any doc changes needed for https://docs.ray.io/en/master/. - [x] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file. - [x] I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/ - Testing Strategy - [x] Unit tests - [x] Release tests --------- Signed-off-by: Arthur <atte.book@gmail.com> Signed-off-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
- Loading branch information