diff --git a/src/replicate/_client.py b/src/replicate/_client.py index 7edeb34..7113e54 100644 --- a/src/replicate/_client.py +++ b/src/replicate/_client.py @@ -20,6 +20,7 @@ import httpx +from replicate.lib.cog import _get_api_token_from_environment from replicate.lib._files import FileEncodingStrategy from replicate.lib._predictions_run import Model, Version, ModelVersionIdentifier from replicate.types.prediction_create_params import PredictionCreateParamsWithoutVersion @@ -108,7 +109,7 @@ def __init__( This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: - bearer_token = os.environ.get("REPLICATE_API_TOKEN") + bearer_token = _get_api_token_from_environment() if bearer_token is None: raise ReplicateError( "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" @@ -419,7 +420,7 @@ def __init__( This automatically infers the `bearer_token` argument from the `REPLICATE_API_TOKEN` environment variable if it is not provided. """ if bearer_token is None: - bearer_token = os.environ.get("REPLICATE_API_TOKEN") + bearer_token = _get_api_token_from_environment() if bearer_token is None: raise ReplicateError( "The bearer_token client option must be set either by passing bearer_token to the client or by setting the REPLICATE_API_TOKEN environment variable" diff --git a/src/replicate/lib/cog.py b/src/replicate/lib/cog.py new file mode 100644 index 0000000..635bf09 --- /dev/null +++ b/src/replicate/lib/cog.py @@ -0,0 +1,41 @@ +"""Cog integration utilities for Replicate.""" + +import os +from typing import Any, Union, Iterator, cast + +from replicate._utils._logs import logger + + +def _get_api_token_from_environment() -> Union[str, None]: + """Get API token from cog current scope if available, otherwise from environment.""" + try: + import cog # type: ignore[import-untyped, import-not-found] + + # Get the current scope - this might return None or raise an exception + scope = getattr(cog, "current_scope", lambda: None)() + if scope is None: + return os.environ.get("REPLICATE_API_TOKEN") + + # Get the context from the scope + context = getattr(scope, "context", None) + if context is None: + return os.environ.get("REPLICATE_API_TOKEN") + + # Get the items method and call it + items_method = getattr(context, "items", None) + if not callable(items_method): + return os.environ.get("REPLICATE_API_TOKEN") + + # Iterate through context items looking for the API token + items = cast(Iterator["tuple[Any, Any]"], items_method()) + for key, value in items: + if str(key).upper() == "REPLICATE_API_TOKEN": + return str(value) if value is not None else value + + except Exception as e: # Catch all exceptions to ensure robust fallback + logger.debug("Failed to retrieve API token from cog.current_scope(): %s", e) + + return os.environ.get("REPLICATE_API_TOKEN") + + +__all__ = ["_get_api_token_from_environment"] diff --git a/src/replicate/types/prediction_create_params.py b/src/replicate/types/prediction_create_params.py index 402db8b..8665e35 100644 --- a/src/replicate/types/prediction_create_params.py +++ b/src/replicate/types/prediction_create_params.py @@ -35,7 +35,7 @@ class PredictionCreateParamsWithoutVersion(TypedDict, total=False): - you don't want to upload and host the file somewhere - you don't need to use the file again (Replicate will not store it) """ - + stream: bool """**This field is deprecated.** diff --git a/tests/test_current_scope.py b/tests/test_current_scope.py new file mode 100644 index 0000000..98ae8b2 --- /dev/null +++ b/tests/test_current_scope.py @@ -0,0 +1,220 @@ +"""Tests for current_scope token functionality.""" + +import os +import sys +from unittest import mock + +import pytest + +from replicate import Replicate, AsyncReplicate +from replicate.lib.cog import _get_api_token_from_environment +from replicate._exceptions import ReplicateError + + +class TestGetApiTokenFromEnvironment: + """Test the _get_api_token_from_environment function.""" + + def test_cog_no_current_scope_method_falls_back_to_env(self): + """Test fallback when cog exists but has no current_scope method.""" + mock_cog = mock.MagicMock() + del mock_cog.current_scope # Remove the method + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_current_scope_returns_none_falls_back_to_env(self): + """Test fallback when current_scope() returns None.""" + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = None + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_no_context_attr_falls_back_to_env(self): + """Test fallback when scope has no context attribute.""" + mock_scope = mock.MagicMock() + del mock_scope.context # Remove the context attribute + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_context_not_dict_falls_back_to_env(self): + """Test fallback when scope.context is not a dictionary.""" + mock_scope = mock.MagicMock() + mock_scope.context = "not a dict" + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_no_replicate_api_token_key_falls_back_to_env(self): + """Test fallback when replicate_api_token key is missing from context.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"other_key": "other_value"} # Missing replicate_api_token + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_cog_scope_replicate_api_token_valid_string(self): + """Test successful retrieval of non-empty token from cog.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "cog-token" + + def test_cog_scope_replicate_api_token_case_insensitive(self): + """Test successful retrieval of non-empty token from cog ignoring case.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"replicate_api_token": "cog-token"} + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "cog-token" + + def test_cog_scope_replicate_api_token_empty_string(self): + """Test that empty string from cog is returned (not falling back to env).""" + mock_scope = mock.MagicMock() + mock_scope.context = {"replicate_api_token": ""} # Empty string + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "" # Should return empty string, not env token + + def test_cog_scope_replicate_api_token_none(self): + """Test that None from cog is returned (not falling back to env).""" + mock_scope = mock.MagicMock() + mock_scope.context = {"replicate_api_token": None} + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token is None # Should return None, not env token + + def test_cog_current_scope_raises_exception_falls_back_to_env(self): + """Test fallback when current_scope() raises an exception.""" + mock_cog = mock.MagicMock() + mock_cog.current_scope.side_effect = RuntimeError("Scope error") + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + token = _get_api_token_from_environment() + assert token == "env-token" + + def test_no_env_token_returns_none(self): + """Test that None is returned when no environment token is set and cog unavailable.""" + with mock.patch.dict(os.environ, {}, clear=True): # Clear all env vars + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token is None + + def test_env_token_empty_string(self): + """Test that empty string from environment is returned.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": ""}): + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token == "" + + def test_env_token_valid_string(self): + """Test that valid token from environment is returned.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": None}): + token = _get_api_token_from_environment() + assert token == "env-token" + + +class TestClientCurrentScopeIntegration: + """Test that the client uses current_scope functionality.""" + + def test_sync_client_uses_current_scope_token(self): + """Test that sync client retrieves token from current_scope.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + # Clear environment variable to ensure we're using cog + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + client = Replicate(base_url="http://test.example.com") + assert client.bearer_token == "cog-token" + + def test_async_client_uses_current_scope_token(self): + """Test that async client retrieves token from current_scope.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + # Clear environment variable to ensure we're using cog + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + client = AsyncReplicate(base_url="http://test.example.com") + assert client.bearer_token == "cog-token" + + def test_sync_client_falls_back_to_env_when_cog_unavailable(self): + """Test that sync client falls back to env when cog is unavailable.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": None}): + client = Replicate(base_url="http://test.example.com") + assert client.bearer_token == "env-token" + + def test_async_client_falls_back_to_env_when_cog_unavailable(self): + """Test that async client falls back to env when cog is unavailable.""" + with mock.patch.dict(os.environ, {"REPLICATE_API_TOKEN": "env-token"}): + with mock.patch.dict(sys.modules, {"cog": None}): + client = AsyncReplicate(base_url="http://test.example.com") + assert client.bearer_token == "env-token" + + def test_sync_client_raises_error_when_no_token_available(self): + """Test that sync client raises error when no token is available.""" + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch.dict(sys.modules, {"cog": None}): + with pytest.raises(ReplicateError, match="bearer_token client option must be set"): + Replicate(base_url="http://test.example.com") + + def test_async_client_raises_error_when_no_token_available(self): + """Test that async client raises error when no token is available.""" + with mock.patch.dict(os.environ, {}, clear=True): + with mock.patch.dict(sys.modules, {"cog": None}): + with pytest.raises(ReplicateError, match="bearer_token client option must be set"): + AsyncReplicate(base_url="http://test.example.com") + + def test_explicit_token_overrides_current_scope(self): + """Test that explicitly provided token overrides current_scope.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + client = Replicate(bearer_token="explicit-token", base_url="http://test.example.com") + assert client.bearer_token == "explicit-token" + + def test_explicit_async_token_overrides_current_scope(self): + """Test that explicitly provided token overrides current_scope for async client.""" + mock_scope = mock.MagicMock() + mock_scope.context = {"REPLICATE_API_TOKEN": "cog-token"} + mock_cog = mock.MagicMock() + mock_cog.current_scope.return_value = mock_scope + + with mock.patch.dict(sys.modules, {"cog": mock_cog}): + client = AsyncReplicate(bearer_token="explicit-token", base_url="http://test.example.com") + assert client.bearer_token == "explicit-token"