From f4b617a8ab1d1cb02f787e96b65d7c8b67b981ad Mon Sep 17 00:00:00 2001 From: Dmitri Khokhlov Date: Thu, 7 Mar 2024 18:15:21 -0800 Subject: [PATCH] list[Path] support Signed-off-by: Dmitri Khokhlov --- python/cog/server/runner.py | 20 +++++++++++------- .../fixtures/path-list-input-project/cog.yaml | 3 +++ .../path-list-input-project/predict.py | 9 ++++++++ .../test_integration/test_predict.py | 21 +++++++++++++++++++ 4 files changed, 45 insertions(+), 8 deletions(-) create mode 100644 test-integration/test_integration/fixtures/path-list-input-project/cog.yaml create mode 100644 test-integration/test_integration/fixtures/path-list-input-project/predict.py diff --git a/python/cog/server/runner.py b/python/cog/server/runner.py index e1ee0257b9..dac02a851e 100644 --- a/python/cog/server/runner.py +++ b/python/cog/server/runner.py @@ -382,15 +382,19 @@ def _predict( input_dict = initial_prediction["input"] for k, v in input_dict.items(): - if isinstance(v, types.URLPath): - try: + try: + # Check if v is an instance of URLPath + if isinstance(v, types.URLPath): input_dict[k] = v.convert() - except requests.exceptions.RequestException as e: - tb = traceback.format_exc() - event_handler.append_logs(tb) - event_handler.failed(error=str(e)) - log.warn("failed to download url path from input", exc_info=True) - return event_handler.response + # Check if v is a list of URLPath instances + elif isinstance(v, list) and all(isinstance(item, types.URLPath) for item in v): + input_dict[k] = [item.convert() for item in v] + except requests.exceptions.RequestException as e: + tb = traceback.format_exc() + event_handler.append_logs(tb) + event_handler.failed(error=str(e)) + log.warn("Failed to download url path from input", exc_info=True) + return event_handler.response for event in worker.predict(input_dict, poll=0.1): if should_cancel.is_set(): diff --git a/test-integration/test_integration/fixtures/path-list-input-project/cog.yaml b/test-integration/test_integration/fixtures/path-list-input-project/cog.yaml new file mode 100644 index 0000000000..7b6d5d4dce --- /dev/null +++ b/test-integration/test_integration/fixtures/path-list-input-project/cog.yaml @@ -0,0 +1,3 @@ +build: + python_version: "3.11" +predict: "predict.py:Predictor" diff --git a/test-integration/test_integration/fixtures/path-list-input-project/predict.py b/test-integration/test_integration/fixtures/path-list-input-project/predict.py new file mode 100644 index 0000000000..dd2b8c8e77 --- /dev/null +++ b/test-integration/test_integration/fixtures/path-list-input-project/predict.py @@ -0,0 +1,9 @@ +from cog import BasePredictor, Path + +class Predictor(BasePredictor): + def predict(self, paths: list[Path]) -> str: + output_parts = [] # Use a list to collect file contents + for path in paths: + with open(path) as f: + output_parts.append(f.read()) + return "".join(output_parts) diff --git a/test-integration/test_integration/test_predict.py b/test-integration/test_integration/test_predict.py index ab78c8d869..eeebeaa926 100644 --- a/test-integration/test_integration/test_predict.py +++ b/test-integration/test_integration/test_predict.py @@ -229,3 +229,24 @@ def test_predict_many_inputs_with_existing_image(docker_image, tmpdir_factory): capture_output=True, ) assert result.stdout.decode() == "hello default 20 world jpg foo 6\n" + + +def test_predict_path_list_input(tmpdir_factory): + project_dir = Path(__file__).parent / "fixtures/path-list-input-project" + out_dir = pathlib.Path(tmpdir_factory.mktemp("project")) + shutil.copytree(project_dir, out_dir, dirs_exist_ok=True) + with open(out_dir / "1.txt", "w") as fh: + fh.write("test1") + with open(out_dir / "2.txt", "w") as fh: + fh.write("test2") + cmd = ["cog", "predict", "-i", "paths=[\"@1.txt\",\"@2.txt\"]"] + + result = subprocess.run( + cmd, + cwd=out_dir, + check=True, + capture_output=True, + ) + stdout = result.stdout.decode() + assert "test1" in stdout + assert "test2" in stdout