Skip to content

Commit

Permalink
list[Path] support
Browse files Browse the repository at this point in the history
Signed-off-by: Dmitri Khokhlov <dkhokhlov@gmail.com>
  • Loading branch information
dkhokhlov committed Mar 8, 2024
1 parent 13e39ca commit f4b617a
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 8 deletions.
20 changes: 12 additions & 8 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,19 @@ def _predict(
input_dict = initial_prediction["input"]

for k, v in input_dict.items():
if isinstance(v, types.URLPath):
try:
try:
# Check if v is an instance of URLPath
if isinstance(v, types.URLPath):
input_dict[k] = v.convert()
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(isinstance(item, types.URLPath) for item in v):
input_dict[k] = [item.convert() for item in v]
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("Failed to download url path from input", exc_info=True)
return event_handler.response

for event in worker.predict(input_dict, poll=0.1):
if should_cancel.is_set():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.11"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from cog import BasePredictor, Path

class Predictor(BasePredictor):
def predict(self, paths: list[Path]) -> str:
output_parts = [] # Use a list to collect file contents
for path in paths:
with open(path) as f:
output_parts.append(f.read())
return "".join(output_parts)
21 changes: 21 additions & 0 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,24 @@ def test_predict_many_inputs_with_existing_image(docker_image, tmpdir_factory):
capture_output=True,
)
assert result.stdout.decode() == "hello default 20 world jpg foo 6\n"


def test_predict_path_list_input(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/path-list-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
with open(out_dir / "1.txt", "w") as fh:
fh.write("test1")
with open(out_dir / "2.txt", "w") as fh:
fh.write("test2")
cmd = ["cog", "predict", "-i", "paths=[\"@1.txt\",\"@2.txt\"]"]

result = subprocess.run(
cmd,
cwd=out_dir,
check=True,
capture_output=True,
)
stdout = result.stdout.decode()
assert "test1" in stdout
assert "test2" in stdout

0 comments on commit f4b617a

Please sign in to comment.