Skip to content

Commit

Permalink
fix: make sunray.Actor type var as covariant (#52)
Browse files Browse the repository at this point in the history
  • Loading branch information
zen-xu committed Jun 4, 2024
1 parent 44e7e4e commit ed8598c
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 13 deletions.
28 changes: 15 additions & 13 deletions sunray/_internal/actor_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
_Ret = TypeVar("_Ret")
_YieldItem = TypeVar("_YieldItem")
_RemoteRet = TypeVar("_RemoteRet", bound=io.Out)
_ClassT = TypeVar("_ClassT")
_ClassT_co = TypeVar("_ClassT_co", covariant=True)
_P = ParamSpec("_P")
_R0 = TypeVar("_R0")
_R1 = TypeVar("_R1")
Expand Down Expand Up @@ -76,14 +76,16 @@ def __init__(*args, **kwargs): # noqa: N807 # pragma: no cover
return klass


class ActorClass(Generic[_P, _ClassT]):
def __init__(self, klass: Callable[_P, _ClassT], default_opts: ActorRemoteOptions):
class ActorClass(Generic[_P, _ClassT_co]):
def __init__(
self, klass: Callable[_P, _ClassT_co], default_opts: ActorRemoteOptions
):
self._klass = add_var_keyword_to_klass(klass)
self._default_opts = default_opts

if TYPE_CHECKING:
remote: RemoteCallable[Callable[_P, _ClassT], io.Out[Actor[_ClassT]]]
bind: ClassBindCallable[Callable[_P, _ClassT], io.Actor[_ClassT]]
remote: RemoteCallable[Callable[_P, _ClassT_co], io.Out[Actor[_ClassT_co]]]
bind: ClassBindCallable[Callable[_P, _ClassT_co], io.Actor[_ClassT_co]]
else:

def remote(self, *args, **kwargs):
Expand All @@ -103,19 +105,19 @@ def bind(self, *args, **kwargs):

def options(
self, **opts: Unpack[ActorRemoteOptions]
) -> ActorClassWrapper[_P, _ClassT]:
) -> ActorClassWrapper[_P, _ClassT_co]:
opts = {**self._default_opts, **opts}
return ActorClassWrapper(self._klass, opts)


class ActorClassWrapper(Generic[_P, _ClassT]):
def __init__(self, klass: Callable[_P, _ClassT], opts: ActorRemoteOptions):
class ActorClassWrapper(Generic[_P, _ClassT_co]):
def __init__(self, klass: Callable[_P, _ClassT_co], opts: ActorRemoteOptions):
self._klass = add_var_keyword_to_klass(klass)
self._opts = opts

if TYPE_CHECKING:
remote: RemoteCallable[Callable[_P, _ClassT], io.Out[Actor[_ClassT]]]
bind: ClassBindCallable[Callable[_P, _ClassT], io.Actor[_ClassT]]
remote: RemoteCallable[Callable[_P, _ClassT_co], io.Out[Actor[_ClassT_co]]]
bind: ClassBindCallable[Callable[_P, _ClassT_co], io.Actor[_ClassT_co]]
else:

def remote(self, *args, **kwargs):
Expand Down Expand Up @@ -162,12 +164,12 @@ def bind(self, *args, **kwargs):
return self.actor_method.bind(*args, **kwargs)


class Actor(Generic[_ClassT]):
class Actor(Generic[_ClassT_co]):
def __init__(self, actor_handle: ActorHandle):
self._actor_handle = actor_handle

@property
def methods(self) -> type[_ClassT]: # pragma: no cover
def methods(self) -> type[_ClassT_co]: # pragma: no cover
return ActorHandleProxy(self._actor_handle) # type: ignore[return-value]

def __repr__(self) -> str:
Expand Down Expand Up @@ -835,5 +837,5 @@ def __getattribute__(self, name: str) -> Any: # pragma: no cover
return attr

@classmethod
def new_actor(cls: Callable[_P, _ClassT]) -> ActorClass[_P, _ClassT]:
def new_actor(cls: Callable[_P, _ClassT_co]) -> ActorClass[_P, _ClassT_co]:
return ActorClass(cls, cls._default_ray_opts) # type: ignore[attr-defined]
19 changes: 19 additions & 0 deletions tests/mypy/test_actor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -388,3 +388,22 @@
async def async_stream(self) -> AsyncGenerator[int, None]:
for i in range(10):
yield i
- case: covariant_actor
main: |
import sunray
class Base(sunray.ActorMixin):
...
class Child(Base):
...
def f1(actor: sunray.Actor[Base]): ...
@sunray.remote
def f2(actor: sunray.Actor[Base]): ...
def main(actor: sunray.Actor[Child]):
f1(actor)
f2.remote(actor)

0 comments on commit ed8598c

Please sign in to comment.