Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/replicate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .lib._models import Model as Model, Version as Version, ModelVersionIdentifier as ModelVersionIdentifier
from ._base_client import DefaultHttpxClient, DefaultAioHttpClient, DefaultAsyncHttpxClient
from ._utils._logs import setup_logging as _setup_logging
from .lib._predictions_use import get_path_url as get_path_url

__all__ = [
"types",
Expand Down Expand Up @@ -89,6 +90,7 @@
"Model",
"Version",
"ModelVersionIdentifier",
"get_path_url",
]

if not _t.TYPE_CHECKING:
Expand All @@ -104,6 +106,9 @@
for __name in __all__:
if not __name.startswith("__"):
try:
# Skip symbols that are imported later from _module_client
if __name in ("run", "use"):
continue
__locals[__name].__module__ = "replicate"
except (TypeError, AttributeError):
# Some of our exported symbols are builtins which we can't set attributes for.
Expand Down
4 changes: 4 additions & 0 deletions src/replicate/lib/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ class FileOutput(httpx.SyncByteStream):
def __init__(self, url: str, client: Replicate) -> None:
self.url = url
self._client = client
# Add __url__ attribute for compatibility with get_path_url()
self.__url__ = url

def read(self) -> bytes:
if self.url.startswith("data:"):
Expand Down Expand Up @@ -184,6 +186,8 @@ class AsyncFileOutput(httpx.AsyncByteStream):
def __init__(self, url: str, client: AsyncReplicate) -> None:
self.url = url
self._client = client
# Add __url__ attribute for compatibility with get_path_url()
self.__url__ = url

async def read(self) -> bytes:
if self.url.startswith("data:"):
Expand Down
106 changes: 106 additions & 0 deletions tests/lib/test_get_path_url.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from pathlib import Path

import replicate
from replicate.lib._files import FileOutput, AsyncFileOutput
from replicate.lib._predictions_use import URLPath, get_path_url

# Test token for client instantiation
TEST_TOKEN = "test-bearer-token"


def test_get_path_url_with_urlpath():
"""Test get_path_url returns the URL for URLPath instances."""
url = "https://example.com/test.jpg"
path_proxy = URLPath(url)

result = get_path_url(path_proxy)
assert result == url


def test_get_path_url_with_fileoutput():
"""Test get_path_url returns the URL for FileOutput instances."""
url = "https://example.com/test.jpg"
file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN))

result = get_path_url(file_output)
assert result == url


def test_get_path_url_with_async_fileoutput():
"""Test get_path_url returns the URL for AsyncFileOutput instances."""
url = "https://example.com/test.jpg"
async_file_output = AsyncFileOutput(url, replicate.AsyncReplicate(bearer_token=TEST_TOKEN))

result = get_path_url(async_file_output)
assert result == url


def test_get_path_url_with_regular_path():
"""Test get_path_url returns None for regular Path instances."""
regular_path = Path("test.txt")

result = get_path_url(regular_path)
assert result is None


def test_get_path_url_with_object_without_target():
"""Test get_path_url returns None for objects without __url__."""

# Test with a string
result = get_path_url("not a path")
assert result is None

# Test with a dict
result = get_path_url({"key": "value"})
assert result is None

# Test with None
result = get_path_url(None)
assert result is None


def test_get_path_url_module_level_import():
"""Test that get_path_url can be imported at module level."""
from replicate import get_path_url as module_get_path_url

url = "https://example.com/test.jpg"
file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN))

result = module_get_path_url(file_output)
assert result == url


def test_get_path_url_direct_module_access():
"""Test that get_path_url can be accessed directly from replicate module."""
url = "https://example.com/test.jpg"
file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN))

result = replicate.get_path_url(file_output)
assert result == url


def test_fileoutput_has_url_attribute():
"""Test that FileOutput instances have __url__ attribute."""
url = "https://example.com/test.jpg"
file_output = FileOutput(url, replicate.Replicate(bearer_token=TEST_TOKEN))

assert hasattr(file_output, "__url__")
assert file_output.__url__ == url


def test_async_fileoutput_has_url_attribute():
"""Test that AsyncFileOutput instances have __url__ attribute."""
url = "https://example.com/test.jpg"
async_file_output = AsyncFileOutput(url, replicate.AsyncReplicate(bearer_token=TEST_TOKEN))

assert hasattr(async_file_output, "__url__")
assert async_file_output.__url__ == url


def test_urlpath_has_url_attribute():
"""Test that URLPath instances have __url__ attribute."""
url = "https://example.com/test.jpg"
url_path = URLPath(url)

assert hasattr(url_path, "__url__")
assert url_path.__url__ == url