diff --git a/src/replicate/_module_client.py b/src/replicate/_module_client.py
index 817c605..a3e8ab4 100644
--- a/src/replicate/_module_client.py
+++ b/src/replicate/_module_client.py
@@ -88,13 +88,17 @@ def _run(*args, **kwargs):
         return _load_client().run(*args, **kwargs)
 
     def _use(ref, *, hint=None, streaming=False, use_async=False, **kwargs):
+        from .lib._predictions_use import use
+
         if use_async:
             # For async, we need to use AsyncReplicate instead
             from ._client import AsyncReplicate
 
-            client = AsyncReplicate()
-            return client.use(ref, hint=hint, streaming=streaming, **kwargs)
-        return _load_client().use(ref, hint=hint, streaming=streaming, **kwargs)
+            return use(AsyncReplicate, ref, hint=hint, streaming=streaming, **kwargs)
+
+        from ._client import Replicate
+
+        return use(Replicate, ref, hint=hint, streaming=streaming, **kwargs)
 
     run = _run
     use = _use
diff --git a/src/replicate/lib/_predictions_use.py b/src/replicate/lib/_predictions_use.py
index 606bbee..1cd085c 100644
--- a/src/replicate/lib/_predictions_use.py
+++ b/src/replicate/lib/_predictions_use.py
@@ -9,6 +9,7 @@
     Any,
     Dict,
     List,
+    Type,
     Tuple,
     Union,
     Generic,
@@ -436,15 +437,18 @@ class Function(Generic[Input, Output]):
     A wrapper for a Replicate model that can be called as a function.
     """
 
-    _client: Client
     _ref: str
     _streaming: bool
 
-    def __init__(self, client: Client, ref: str, *, streaming: bool) -> None:
-        self._client = client
+    def __init__(self, client: Type[Client], ref: str, *, streaming: bool) -> None:
+        self._client_class = client
         self._ref = ref
         self._streaming = streaming
 
+    @property
+    def _client(self) -> Client:
+        return self._client_class()
+
     def __call__(self, *args: Input.args, **inputs: Input.kwargs) -> Output:
         return self.create(*args, **inputs).output()
 
@@ -666,16 +670,19 @@ class AsyncFunction(Generic[Input, Output]):
     An async wrapper for a Replicate model that can be called as a function.
     """
 
-    _client: AsyncClient
     _ref: str
     _streaming: bool
     _openapi_schema: Optional[Dict[str, Any]] = None
 
-    def __init__(self, client: AsyncClient, ref: str, *, streaming: bool) -> None:
-        self._client = client
+    def __init__(self, client: Type[AsyncClient], ref: str, *, streaming: bool) -> None:
+        self._client_class = client
         self._ref = ref
         self._streaming = streaming
 
+    @property
+    def _client(self) -> AsyncClient:
+        return self._client_class()
+
     @cached_property
     def _parsed_ref(self) -> Tuple[str, str, Optional[str]]:
         return ModelVersionIdentifier.parse(self._ref)
@@ -804,7 +811,7 @@ async def openapi_schema(self) -> Dict[str, Any]:
 
 @overload
 def use(
-    client: Client,
+    client: Type[Client],
     ref: Union[str, FunctionRef[Input, Output]],
     *,
     hint: Optional[Callable[Input, Output]] = None,
@@ -814,7 +821,7 @@ def use(
 
 @overload
 def use(
-    client: Client,
+    client: Type[Client],
     ref: Union[str, FunctionRef[Input, Output]],
     *,
     hint: Optional[Callable[Input, Output]] = None,
@@ -824,7 +831,7 @@ def use(
 
 @overload
 def use(
-    client: AsyncClient,
+    client: Type[AsyncClient],
     ref: Union[str, FunctionRef[Input, Output]],
     *,
     hint: Optional[Callable[Input, Output]] = None,
@@ -834,7 +841,7 @@ def use(
 
 @overload
 def use(
-    client: AsyncClient,
+    client: Type[AsyncClient],
     ref: Union[str, FunctionRef[Input, Output]],
     *,
     hint: Optional[Callable[Input, Output]] = None,
@@ -843,7 +850,7 @@ def use(
 
 
 def use(
-    client: Union[Client, AsyncClient],
+    client: Union[Type[Client], Type[AsyncClient]],
     ref: Union[str, FunctionRef[Input, Output]],
     *,
     hint: Optional[Callable[Input, Output]] = None,  # pylint: disable=unused-argument # noqa: ARG001 # required for type inference
@@ -868,7 +875,7 @@ def use(
     except AttributeError:
         pass
 
-    if isinstance(client, AsyncClient):
+    if issubclass(client, AsyncClient):
         # TODO: Fix type inference for AsyncFunction return type
         return AsyncFunction(client, str(ref), streaming=streaming)  # type: ignore[return-value]
 
diff --git a/tests/test_simple_lazy.py b/tests/test_simple_lazy.py
new file mode 100644
index 0000000..312c9fe
--- /dev/null
+++ b/tests/test_simple_lazy.py
@@ -0,0 +1,52 @@
+"""Test lazy client creation in replicate.use()."""
+
+import os
+import sys
+from unittest.mock import MagicMock, patch
+
+
+def test_use_does_not_raise_without_token():
+    """Test that replicate.use() works even when no API token is available."""
+    sys.path.insert(0, "src")
+
+    with patch.dict(os.environ, {}, clear=True):
+        with patch.dict(sys.modules, {"cog": None}):
+            import replicate
+
+            # Should not raise an exception
+            model = replicate.use("test/model")  # type: ignore[misc]
+            assert model is not None
+
+
+def test_cog_current_scope():
+    """Test that cog.current_scope().context is read on each client creation."""
+    sys.path.insert(0, "src")
+
+    mock_context = MagicMock()
+    mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-1")]
+
+    mock_scope = MagicMock()
+    mock_scope.context = mock_context
+
+    mock_cog = MagicMock()
+    mock_cog.current_scope.return_value = mock_scope
+
+    with patch.dict(os.environ, {}, clear=True):
+        with patch.dict(sys.modules, {"cog": mock_cog}):
+            import replicate
+
+            model = replicate.use("test/model")  # type: ignore[misc]
+
+            # Access the client property - this should trigger client creation and cog.current_scope call
+            _ = model._client
+
+            assert mock_cog.current_scope.call_count == 1
+
+            # Change the token and access client again - should trigger another call
+            mock_context.items.return_value = [("REPLICATE_API_TOKEN", "test-token-2")]
+
+            # Create a new model to trigger another client creation
+            model2 = replicate.use("test/model2")  # type: ignore[misc]
+            _ = model2._client
+
+            assert mock_cog.current_scope.call_count == 2