diff --git a/src/replicate/lib/_files.py b/src/replicate/lib/_files.py index c14a944..7c6f485 100644 --- a/src/replicate/lib/_files.py +++ b/src/replicate/lib/_files.py @@ -51,9 +51,8 @@ def encode_json( if file_encoding_strategy == "base64": return base64_encode_file(obj) else: - # todo: support files endpoint - # return client.files.create(obj).urls["get"] - raise NotImplementedError("File upload is not supported yet") + response = client.files.create(content=obj.read()) + return response.urls.get if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) @@ -91,9 +90,8 @@ async def async_encode_json( # TODO: This should ideally use an async based file reader path. return base64_encode_file(obj) else: - # todo: support files endpoint - # return (await client.files.async_create(obj)).urls["get"] - raise NotImplementedError("File upload is not supported yet") + response = await client.files.create(content=obj.read()) + return response.urls.get if HAS_NUMPY: if isinstance(obj, np.integer): # type: ignore return int(obj) diff --git a/tests/lib/test_run.py b/tests/lib/test_run.py index 43df10d..168447d 100644 --- a/tests/lib/test_run.py +++ b/tests/lib/test_run.py @@ -10,9 +10,11 @@ from respx import MockRouter from replicate import Replicate, AsyncReplicate +from replicate._compat import model_dump from replicate.lib._files import FileOutput, AsyncFileOutput from replicate._exceptions import ModelError, NotFoundError, BadRequestError from replicate.lib._models import Model, Version, ModelVersionIdentifier +from replicate.types.file_create_response import URLs, Checksums, FileCreateResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") bearer_token = "My Bearer Token" @@ -89,6 +91,16 @@ class TestRun: # Common model reference format that will work with the new SDK model_ref = "owner/name:version" + file_create_response = FileCreateResponse( + id="test_file_id", + checksums=Checksums(sha256="test_sha256"), + content_type="application/octet-stream", + created_at=datetime.datetime.now(), + expires_at=datetime.datetime.now() + datetime.timedelta(days=1), + metadata={}, + size=1234, + urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"), + ) @pytest.mark.respx(base_url=base_url) def test_run_basic(self, respx_mock: MockRouter) -> None: @@ -236,6 +248,23 @@ def test_run_with_base64_file(self, respx_mock: MockRouter) -> None: assert output == "test output" + @pytest.mark.respx(base_url=base_url) + def test_run_with_file_upload(self, respx_mock: MockRouter) -> None: + """Test run with base64 encoded file input.""" + # Create a simple file-like object + file_obj = io.BytesIO(b"test content") + + # Mock the prediction response + respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction())) + # Mock the file upload endpoint + respx_mock.post("/files").mock( + return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json")) + ) + + output: Any = self.client.run(self.model_ref, input={"file": file_obj}) + + assert output == "test output" + def test_run_with_prefer_conflict(self) -> None: """Test run with conflicting wait and prefer parameters.""" with pytest.raises(TypeError, match="cannot mix and match prefer and wait"): @@ -349,6 +378,16 @@ class TestAsyncRun: # Common model reference format that will work with the new SDK model_ref = "owner/name:version" + file_create_response = FileCreateResponse( + id="test_file_id", + checksums=Checksums(sha256="test_sha256"), + content_type="application/octet-stream", + created_at=datetime.datetime.now(), + expires_at=datetime.datetime.now() + datetime.timedelta(days=1), + metadata={}, + size=1234, + urls=URLs(get="https://api.replicate.com/v1/files/test_file_id"), + ) @pytest.mark.respx(base_url=base_url) async def test_async_run_basic(self, respx_mock: MockRouter) -> None: @@ -501,6 +540,23 @@ async def test_async_run_with_base64_file(self, respx_mock: MockRouter) -> None: assert output == "test output" + @pytest.mark.respx(base_url=base_url) + async def test_async_run_with_file_upload(self, respx_mock: MockRouter) -> None: + """Test async run with base64 encoded file input.""" + # Create a simple file-like object + file_obj = io.BytesIO(b"test content") + + # Mock the prediction response + respx_mock.post("/predictions").mock(return_value=httpx.Response(201, json=create_mock_prediction())) + # Mock the file upload endpoint + respx_mock.post("/files").mock( + return_value=httpx.Response(201, json=model_dump(self.file_create_response, mode="json")) + ) + + output: Any = await self.client.run(self.model_ref, input={"file": file_obj}) + + assert output == "test output" + async def test_async_run_with_prefer_conflict(self) -> None: """Test async run with conflicting wait and prefer parameters.""" with pytest.raises(TypeError, match="cannot mix and match prefer and wait"):