Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added typehint overloads to accurately infer the return type for ray.… #45033

Merged
merged 4 commits into from
May 17, 2024

Conversation

ArthurBook
Copy link
Contributor

@ArthurBook ArthurBook commented Apr 29, 2024

This PR solves issue #45032 by implementing overloads that handle the return type inference for synchronous and asynchronous functions and methods decorated with @ray.serve.batch. The change improves compatibility with pylance, mypy and other type checkers, enhancing user-friendlyness. NOTE: No runtime changes are expected.

Before PR:

'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve


@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(integers: list[int]) -> list[int]  NOTE: still expecting a list[int] return type

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] NOTE: still expecting a list[int] return type


@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) batch_fn_sync_w_args: G@batch NOTE: G@batch is un-inferred by pylance


@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) batch_fn_async_w_args: G@batch NOTE: G@batch is un-inferred by pylance



class Server:
    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]


e = Server().batch_meth_sync_no_args(1)  # Argument of type "Literal[1]" cannot be assigned to parameter "integers" of type "list[int]" in function "batch_meth_sync_no_args"
# (method) def batch_meth_sync_no_args(integers: list[int]) -> list[int]
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
# (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
g = Server().batch_meth_sync_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_async_w_args() -> R

In summary, before the PR, the typechecker (Pylance in the above case) is unable to match the @ray.serve.batch function signature of the functions and methods that it is intended to decorate. This causes the input and return type to be inferred incorrectly:

  1. The decorated function input type is inferred as as a list, when it in fact should be a scalar.
  2. The decorated function return type is inferred as as a list, when it in fact should be a scalar.
  3. For the decorated method, 0 positional input args are expected
  4. For the decorated method, the return type is the unbound TypeVar R.

After PR:

'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve


@ray.serve.batch
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(int) -> int

@ray.serve.batch
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(int) -> Coroutine[Any, Any, int]


@ray.serve.batch(max_batch_size=2)
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) def batch_fn_sync_w_args(int) -> int


@ray.serve.batch
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) def batch_fn_async_w_args(int) -> Coroutine[Any, Any, int]


class Server:
    @ray.serve.batch
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]


e = Server().batch_meth_sync_no_args(1)  # (method) def batch_meth_sync_no_args(int) -> int
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(int) -> Coroutine[Any, Any, int]
g = Server().batch_meth_sync_w_args(1) # (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # (method) def batch_meth_async_w_args(int) -> Coroutine[Any, Any, int]

Why are these changes needed?

The return-type for a @ray.serve.batch decorated function is not inferred correctly, making mypy / pylance (probably other checkers too) lost the function signature for the decorated function/method.

Related issue number

Closes #45032

Checks

  • I've signed off every commit(by using the -s flag, i.e., git commit -s) in this PR.
  • I've run scripts/format.sh to lint the changes in this PR.
  • I've included any doc changes needed for https://docs.ray.io/en/master/.
    • I've added any new APIs to the API Reference. For example, if I added a
      method in Tune, I've added it in doc/source/tune/api/ under the
      corresponding .rst file.
  • I've made sure the tests are passing. Note that there might be a few flaky tests, see the recent failures at https://flakey-tests.ray.io/
  • Testing Strategy
    • Unit tests
    • Release tests

…serve.batch() on (a)sync functions and methods

Signed-off-by: Arthur <atte.book@gmail.com>
Signed-off-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
@GeneDer GeneDer requested review from edoakes, shrekris-anyscale and a team May 2, 2024 15:53
@GeneDer
Copy link
Contributor

GeneDer commented May 2, 2024

@edoakes @shrekris-anyscale can you help to take a look at this when you have a sec

@shrekris-anyscale
Copy link
Contributor

@ArthurBook Thanks for the contribution! Out of curiosity, does this also work if you decorate methods using @serve.batch instead of @ray.serve.batch? E.g.:

from ray import serve

@serve.batch
def my_batch_f(input_list):
    ...

@ArthurBook
Copy link
Contributor Author

Hi, yes, just tried it, works independently of how you import.

    # module import
    from ray import serve

    @serve.batch
    def batch_fn(_: list[int]) -> list[int]: ...
    a = batch_fn(1)  # (function) def batch_fn(int) -> int
    
    # fn import
    from ray.serve import batch
    @batch
    def batch_fn(_: list[int]) -> list[int]: ...
    a = batch_fn(1)  # (function) def batch_fn(int) -> int

@ArthurBook Thanks for the contribution! Out of curiosity, does this also work if you decorate methods using @serve.batch instead of @ray.serve.batch? E.g.:

from ray import serve

@serve.batch
def my_batch_f(input_list):
    ...

@anyscalesam anyscalesam added triage Needs triage (eg: priority, bug/not-bug, and owning component) serve Ray Serve Related Issue labels May 2, 2024
Copy link
Contributor

@shrekris-anyscale shrekris-anyscale left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for this contribution! I'm a little unfamiliar with some of this typing syntax, so most of my questions are to understand these changes a bit more. Let me know if you want me clarify anything.

python/ray/serve/batching.py Show resolved Hide resolved
@@ -445,33 +449,90 @@ def _validate_batch_wait_timeout_s(batch_wait_timeout_s):
)


SelfType = TypeVar("SelfType", contravariant=True)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why set contravaraint=True here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is usually done for protocols where the input type is allowed to be any subtype of the given variable. Here is a similar example from pep.

In our case, SelfType is used as a input parameter to the __call__ method of the _BatchDecorator Protocol, so I hinted it to accept any subtype of the SelfType. This means that the decorated method is OK for use also if you subclass _BatchDecorator.

python/ray/serve/batching.py Show resolved Hide resolved
python/ray/serve/batching.py Show resolved Hide resolved
@ArthurBook
Copy link
Contributor Author

Hi @shrekris-anyscale thanks for reviewing! Let me know if you'd like me to elaborate further on any of the points!

Copy link
Contributor

@shrekris-anyscale shrekris-anyscale left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your contribution!

@edoakes edoakes added the go add ONLY when ready to merge, run all tests label May 17, 2024
@edoakes edoakes enabled auto-merge (squash) May 17, 2024 19:36
@edoakes edoakes merged commit 3acce1c into ray-project:master May 17, 2024
8 checks passed
@ArthurBook ArthurBook deleted the ray.serve.batch-types branch May 17, 2024 23:22
ryanaoleary pushed a commit to ryanaoleary/ray that referenced this pull request Jun 6, 2024
ray-project#45033)

This PR solves issue
[ray-project#45032](ray-project#45032) by
implementing overloads that handle the return type inference for
synchronous and asynchronous functions and methods decorated with
`@ray.serve.batch`. The change improves compatibility with pylance, mypy
and other type checkers, enhancing user-friendlyness. NOTE: No runtime
changes are expected.

# Before PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(integers: list[int]) -> list[int]  NOTE: still expecting a list[int] return type

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] NOTE: still expecting a list[int] return type

@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) batch_fn_sync_w_args: G@batch NOTE: G@batch is un-inferred by pylance

@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) batch_fn_async_w_args: G@batch NOTE: G@batch is un-inferred by pylance

class Server:
    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

e = Server().batch_meth_sync_no_args(1)  # Argument of type "Literal[1]" cannot be assigned to parameter "integers" of type "list[int]" in function "batch_meth_sync_no_args"
# (method) def batch_meth_sync_no_args(integers: list[int]) -> list[int]
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
# (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
g = Server().batch_meth_sync_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_async_w_args() -> R

```
In summary, before the PR, the typechecker (Pylance in the above case)
is unable to match the `@ray.serve.batch` function signature of the
functions and methods that it is intended to decorate. This causes the
input and return type to be inferred incorrectly:
1) The decorated function input type is inferred as as a list, when it
in fact should be a scalar.
2) The decorated function return type is inferred as as a list, when it
in fact should be a scalar.
3) For the decorated method, 0 positional input args are expected
4) For the decorated method, the return type is the unbound TypeVar R.

# After PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve

@ray.serve.batch
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(int) -> int

@ray.serve.batch
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(int) -> Coroutine[Any, Any, int]

@ray.serve.batch(max_batch_size=2)
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) def batch_fn_sync_w_args(int) -> int

@ray.serve.batch
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) def batch_fn_async_w_args(int) -> Coroutine[Any, Any, int]

class Server:
    @ray.serve.batch
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

e = Server().batch_meth_sync_no_args(1)  # (method) def batch_meth_sync_no_args(int) -> int
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(int) -> Coroutine[Any, Any, int]
g = Server().batch_meth_sync_w_args(1) # (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # (method) def batch_meth_async_w_args(int) -> Coroutine[Any, Any, int]
```

<!-- Thank you for your contribution! Please review
https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?
The return-type for a `@ray.serve.batch` decorated function is not
inferred correctly, making mypy / pylance (probably other checkers too)
lost the function signature for the decorated function/method.
<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number
Closes ray-project#45032

## Checks
- [x] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [x] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [x] Release tests

---------

Signed-off-by: Arthur <atte.book@gmail.com>
Signed-off-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
ryanaoleary pushed a commit to ryanaoleary/ray that referenced this pull request Jun 6, 2024
ray-project#45033)

This PR solves issue
[ray-project#45032](ray-project#45032) by
implementing overloads that handle the return type inference for
synchronous and asynchronous functions and methods decorated with
`@ray.serve.batch`. The change improves compatibility with pylance, mypy
and other type checkers, enhancing user-friendlyness. NOTE: No runtime
changes are expected.

# Before PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(integers: list[int]) -> list[int]  NOTE: still expecting a list[int] return type

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] NOTE: still expecting a list[int] return type

@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) batch_fn_sync_w_args: G@batch NOTE: G@batch is un-inferred by pylance

@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) batch_fn_async_w_args: G@batch NOTE: G@batch is un-inferred by pylance

class Server:
    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

e = Server().batch_meth_sync_no_args(1)  # Argument of type "Literal[1]" cannot be assigned to parameter "integers" of type "list[int]" in function "batch_meth_sync_no_args"
# (method) def batch_meth_sync_no_args(integers: list[int]) -> list[int]
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
# (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
g = Server().batch_meth_sync_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_async_w_args() -> R

```
In summary, before the PR, the typechecker (Pylance in the above case)
is unable to match the `@ray.serve.batch` function signature of the
functions and methods that it is intended to decorate. This causes the
input and return type to be inferred incorrectly:
1) The decorated function input type is inferred as as a list, when it
in fact should be a scalar.
2) The decorated function return type is inferred as as a list, when it
in fact should be a scalar.
3) For the decorated method, 0 positional input args are expected
4) For the decorated method, the return type is the unbound TypeVar R.

# After PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve

@ray.serve.batch
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(int) -> int

@ray.serve.batch
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(int) -> Coroutine[Any, Any, int]

@ray.serve.batch(max_batch_size=2)
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) def batch_fn_sync_w_args(int) -> int

@ray.serve.batch
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) def batch_fn_async_w_args(int) -> Coroutine[Any, Any, int]

class Server:
    @ray.serve.batch
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

e = Server().batch_meth_sync_no_args(1)  # (method) def batch_meth_sync_no_args(int) -> int
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(int) -> Coroutine[Any, Any, int]
g = Server().batch_meth_sync_w_args(1) # (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # (method) def batch_meth_async_w_args(int) -> Coroutine[Any, Any, int]
```

<!-- Thank you for your contribution! Please review
https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?
The return-type for a `@ray.serve.batch` decorated function is not
inferred correctly, making mypy / pylance (probably other checkers too)
lost the function signature for the decorated function/method.
<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number
Closes ray-project#45032

## Checks
- [x] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [x] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [x] Release tests

---------

Signed-off-by: Arthur <atte.book@gmail.com>
Signed-off-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
Signed-off-by: Ryan O'Leary <ryanaoleary@google.com>
ryanaoleary pushed a commit to ryanaoleary/ray that referenced this pull request Jun 7, 2024
ray-project#45033)

This PR solves issue
[ray-project#45032](ray-project#45032) by
implementing overloads that handle the return type inference for
synchronous and asynchronous functions and methods decorated with
`@ray.serve.batch`. The change improves compatibility with pylance, mypy
and other type checkers, enhancing user-friendlyness. NOTE: No runtime
changes are expected.

# Before PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve


@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(integers: list[int]) -> list[int]  NOTE: still expecting a list[int] return type

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] NOTE: still expecting a list[int] return type


@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) batch_fn_sync_w_args: G@batch NOTE: G@batch is un-inferred by pylance


@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) batch_fn_async_w_args: G@batch NOTE: G@batch is un-inferred by pylance



class Server:
    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]


e = Server().batch_meth_sync_no_args(1)  # Argument of type "Literal[1]" cannot be assigned to parameter "integers" of type "list[int]" in function "batch_meth_sync_no_args"
# (method) def batch_meth_sync_no_args(integers: list[int]) -> list[int]
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
# (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
g = Server().batch_meth_sync_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_async_w_args() -> R

```
In summary, before the PR, the typechecker (Pylance in the above case)
is unable to match the `@ray.serve.batch` function signature of the
functions and methods that it is intended to decorate. This causes the
input and return type to be inferred incorrectly:
1) The decorated function input type is inferred as as a list, when it
in fact should be a scalar.
2) The decorated function return type is inferred as as a list, when it
in fact should be a scalar.
3) For the decorated method, 0 positional input args are expected
4) For the decorated method, the return type is the unbound TypeVar R.

# After PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve


@ray.serve.batch
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(int) -> int

@ray.serve.batch
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(int) -> Coroutine[Any, Any, int]


@ray.serve.batch(max_batch_size=2)
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) def batch_fn_sync_w_args(int) -> int


@ray.serve.batch
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) def batch_fn_async_w_args(int) -> Coroutine[Any, Any, int]


class Server:
    @ray.serve.batch
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]


e = Server().batch_meth_sync_no_args(1)  # (method) def batch_meth_sync_no_args(int) -> int
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(int) -> Coroutine[Any, Any, int]
g = Server().batch_meth_sync_w_args(1) # (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # (method) def batch_meth_async_w_args(int) -> Coroutine[Any, Any, int]
```

<!-- Thank you for your contribution! Please review
https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?
The return-type for a `@ray.serve.batch` decorated function is not
inferred correctly, making mypy / pylance (probably other checkers too)
lost the function signature for the decorated function/method.
<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number
Closes ray-project#45032

## Checks
- [x] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [x] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [x] Release tests

---------

Signed-off-by: Arthur <atte.book@gmail.com>
Signed-off-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
GabeChurch pushed a commit to GabeChurch/ray that referenced this pull request Jun 11, 2024
ray-project#45033)

This PR solves issue
[ray-project#45032](ray-project#45032) by
implementing overloads that handle the return type inference for
synchronous and asynchronous functions and methods decorated with
`@ray.serve.batch`. The change improves compatibility with pylance, mypy
and other type checkers, enhancing user-friendlyness. NOTE: No runtime
changes are expected.

# Before PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(integers: list[int]) -> list[int]  NOTE: still expecting a list[int] return type

@ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]] NOTE: still expecting a list[int] return type

@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) batch_fn_sync_w_args: G@batch NOTE: G@batch is un-inferred by pylance

@ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) batch_fn_async_w_args: G@batch NOTE: G@batch is un-inferred by pylance

class Server:
    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2) # No overloads for "batch" match the provided argumentsPylancereportCallIssue
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

e = Server().batch_meth_sync_no_args(1)  # Argument of type "Literal[1]" cannot be assigned to parameter "integers" of type "list[int]" in function "batch_meth_sync_no_args"
# (method) def batch_meth_sync_no_args(integers: list[int]) -> list[int]
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
# (method) def batch_meth_async_no_args(integers: list[int]) -> Coroutine[Any, Any, list[int]]
g = Server().batch_meth_sync_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # Expected 0 positional argumentsPylancereportCallIssue
# (method) def batch_meth_async_w_args() -> R

```
In summary, before the PR, the typechecker (Pylance in the above case)
is unable to match the `@ray.serve.batch` function signature of the
functions and methods that it is intended to decorate. This causes the
input and return type to be inferred incorrectly:
1) The decorated function input type is inferred as as a list, when it
in fact should be a scalar.
2) The decorated function return type is inferred as as a list, when it
in fact should be a scalar.
3) For the decorated method, 0 positional input args are expected
4) For the decorated method, the return type is the unbound TypeVar R.

# After PR:
```python
'''Test the type infernece for return value of a batched function or method with pylance.'''
import ray.serve

@ray.serve.batch
def batch_fn_sync_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
a = batch_fn_sync_no_args(1) # (function) def batch_fn_sync_no_args(int) -> int

@ray.serve.batch
async def batch_fn_async_no_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
b = batch_fn_async_no_args(1) # (function) def batch_fn_async_no_args(int) -> Coroutine[Any, Any, int]

@ray.serve.batch(max_batch_size=2)
def batch_fn_sync_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
c = batch_fn_sync_w_args(1) # (function) def batch_fn_sync_w_args(int) -> int

@ray.serve.batch
async def batch_fn_async_w_args(integers: list[int]) -> list[int]:
    "docs"
    return [i * 2 for i in integers]
d = batch_fn_async_w_args(1) # (function) def batch_fn_async_w_args(int) -> Coroutine[Any, Any, int]

class Server:
    @ray.serve.batch
    def batch_meth_sync_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch
    async def batch_meth_async_no_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    def batch_meth_sync_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

    @ray.serve.batch(max_batch_size=2)
    async def batch_meth_async_w_args(self, integers: list[int]) -> list[int]:
        "docs"
        return [i * 2 for i in integers]

e = Server().batch_meth_sync_no_args(1)  # (method) def batch_meth_sync_no_args(int) -> int
f = Server().batch_meth_async_no_args(1) # (method) def batch_meth_async_no_args(int) -> Coroutine[Any, Any, int]
g = Server().batch_meth_sync_w_args(1) # (method) def batch_meth_sync_w_args(int) -> int
h = Server().batch_meth_async_w_args(1) # (method) def batch_meth_async_w_args(int) -> Coroutine[Any, Any, int]
```

<!-- Thank you for your contribution! Please review
https://github.com/ray-project/ray/blob/master/CONTRIBUTING.rst before
opening a pull request. -->

<!-- Please add a reviewer to the assignee section when you create a PR.
If you don't have the access to it, we will shortly find a reviewer and
assign them to your PR. -->

## Why are these changes needed?
The return-type for a `@ray.serve.batch` decorated function is not
inferred correctly, making mypy / pylance (probably other checkers too)
lost the function signature for the decorated function/method.
<!-- Please give a short summary of the change and the problem this
solves. -->

## Related issue number
Closes ray-project#45032

## Checks
- [x] I've signed off every commit(by using the -s flag, i.e., `git
commit -s`) in this PR.
- [x] I've run `scripts/format.sh` to lint the changes in this PR.
- [x] I've included any doc changes needed for
https://docs.ray.io/en/master/.
- [x] I've added any new APIs to the API Reference. For example, if I
added a
method in Tune, I've added it in `doc/source/tune/api/` under the
           corresponding `.rst` file.
- [x] I've made sure the tests are passing. Note that there might be a
few flaky tests, see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
   - [x] Unit tests
   - [x] Release tests

---------

Signed-off-by: Arthur <atte.book@gmail.com>
Signed-off-by: Arthur Böök <49250723+ArthurBook@users.noreply.github.com>
Signed-off-by: gchurch <gabe1church@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
go add ONLY when ready to merge, run all tests serve Ray Serve Related Issue triage Needs triage (eg: priority, bug/not-bug, and owning component)
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ray.serve.batch return type problems with pylance and mypy
5 participants