diff --git a/src/replicate/__init__.py b/src/replicate/__init__.py index 90fcf8f..f424c6b 100644 --- a/src/replicate/__init__.py +++ b/src/replicate/__init__.py @@ -43,6 +43,7 @@ from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient from ._utils._logs import setup_logging as _setup_logging +from .lib._predictions_use import get_path_url as get_path_url __all__ = [ "types", @@ -89,6 +90,7 @@ "Model", "Version", "ModelVersionIdentifier", + "get_path_url", ] if not _t.TYPE_CHECKING: @@ -104,6 +106,9 @@ for __name in __all__: if not __name.startswith("__"): try: + # Skip symbols that are imported later from _module_client + if __name in ("run", "use"): + continue __locals[__name].__module__ = "replicate" except (TypeError, AttributeError): # Some of our exported symbols are builtins which we can't set attributes for. diff --git a/src/replicate/lib/_files.py b/src/replicate/lib/_files.py index 7c6f485..ad49a4c 100644 --- a/src/replicate/lib/_files.py +++ b/src/replicate/lib/_files.py @@ -139,6 +139,8 @@ class FileOutput(httpx.SyncByteStream): def __init__(self, url: str, client: Replicate) -> None: self.url = url self._client = client + # Add __url__ attribute for compatibility with get_path_url() + self.__url__ = url def read(self) -> bytes: if self.url.startswith("data:"): @@ -184,6 +186,8 @@ class AsyncFileOutput(httpx.AsyncByteStream): def __init__(self, url: str, client: AsyncReplicate) -> None: self.url = url self._client = client + # Add __url__ attribute for compatibility with get_path_url() + self.__url__ = url async def read(self) -> bytes: if self.url.startswith("data:"): diff --git a/tests/lib/test_get_path_url.py b/tests/lib/test_get_path_url.py new file mode 100644 index 0000000..9d6fda8 --- /dev/null +++ b/tests/lib/test_get_path_url.py @@ -0,0 +1,106 @@ +from pathlib import Path + +import replicate +from replicate.lib._files import FileOutput, AsyncFileOutput +from replicate.lib._predictions_use import URLPath, get_path_url + +# Test token for client instantiation +TEST_TOKEN = "test-bearer-token" + + +def test_get_path_url_with_urlpath(): + """Test get_path_url returns the URL for URLPath instances.""" + url = "https://example.com/test.jpg" + path_proxy = URLPath(url) + + result = get_path_url(path_proxy) + assert result == url + + +def test_get_path_url_with_fileoutput(): + """Test get_path_url returns the URL for FileOutput instances.""" + url = "https://example.com/test.jpg" + file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN)) + + result = get_path_url(file_output) + assert result == url + + +def test_get_path_url_with_async_fileoutput(): + """Test get_path_url returns the URL for AsyncFileOutput instances.""" + url = "https://example.com/test.jpg" + async_file_output = AsyncFileOutput(url, replicate.AsyncReplicate(bearer_token=TEST_TOKEN)) + + result = get_path_url(async_file_output) + assert result == url + + +def test_get_path_url_with_regular_path(): + """Test get_path_url returns None for regular Path instances.""" + regular_path = Path("test.txt") + + result = get_path_url(regular_path) + assert result is None + + +def test_get_path_url_with_object_without_target(): + """Test get_path_url returns None for objects without __url__.""" + + # Test with a string + result = get_path_url("not a path") + assert result is None + + # Test with a dict + result = get_path_url({"key": "value"}) + assert result is None + + # Test with None + result = get_path_url(None) + assert result is None + + +def test_get_path_url_module_level_import(): + """Test that get_path_url can be imported at module level.""" + from replicate import get_path_url as module_get_path_url + + url = "https://example.com/test.jpg" + file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN)) + + result = module_get_path_url(file_output) + assert result == url + + +def test_get_path_url_direct_module_access(): + """Test that get_path_url can be accessed directly from replicate module.""" + url = "https://example.com/test.jpg" + file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN)) + + result = replicate.get_path_url(file_output) + assert result == url + + +def test_fileoutput_has_url_attribute(): + """Test that FileOutput instances have __url__ attribute.""" + url = "https://example.com/test.jpg" + file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN)) + + assert hasattr(file_output, "__url__") + assert file_output.__url__ == url + + +def test_async_fileoutput_has_url_attribute(): + """Test that AsyncFileOutput instances have __url__ attribute.""" + url = "https://example.com/test.jpg" + async_file_output = AsyncFileOutput(url, replicate.AsyncReplicate(bearer_token=TEST_TOKEN)) + + assert hasattr(async_file_output, "__url__") + assert async_file_output.__url__ == url + + +def test_urlpath_has_url_attribute(): + """Test that URLPath instances have __url__ attribute.""" + url = "https://example.com/test.jpg" + url_path = URLPath(url) + + assert hasattr(url_path, "__url__") + assert url_path.__url__ == url