Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for list in inputs #1561

Merged
merged 6 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions docs/python.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,29 @@ class Predictor(BasePredictor):
upscaled_image.save(output_path)
return Path(output_path)
```

## `List`

The List type is also supported in inputs. It can hold any supported type.

Example for **List[Path]**:
```py
class Predictor(BasePredictor):
def predict(self, paths: list[Path]) -> str:
output_parts = [] # Use a list to collect file contents
for path in paths:
with open(path) as f:
output_parts.append(f.read())
return "".join(output_parts)
```
The corresponding cog command:
```bash
$ echo test1 > 1.txt
$ echo test2 > 2.txt
$ cog predict -i paths=@1.txt -i paths=@2.txt
Running prediction...
test1

test2
```
- Note the repeated inputs with the same name "paths" which constitute the list
5 changes: 3 additions & 2 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openap
}

func parseInputFlags(inputs []string) (predict.Inputs, error) {
keyVals := map[string]string{}
keyVals := map[string][]string{}
for _, input := range inputs {
var name, value string

Expand All @@ -305,7 +305,8 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) {
value = value[1 : len(value)-1]
}

keyVals[name] = value
// Append new values to the slice associated with the key
keyVals[name] = append(keyVals[name], value)
}

return predict.NewInputs(keyVals), nil
Expand Down
79 changes: 57 additions & 22 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
package predict

import (
"fmt"
"mime"
"os"
"path/filepath"
"strings"

"github.com/mitchellh/go-homedir"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import moved, and it should move back :)

"github.com/vincent-petithory/dataurl"

"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/util/mime"
)

type Input struct {
String *string
File *string
Array *[]any
}

type Inputs map[string]Input

func NewInputs(keyVals map[string]string) Inputs {
func NewInputs(keyVals map[string][]string) Inputs {
input := Inputs{}
for key, val := range keyVals {
val := val
if strings.HasPrefix(val, "@") {
val = val[1:]
expandedVal, err := homedir.Expand(val)
if err != nil {
// FIXME: handle this better?
console.Warnf("Error expanding homedir: %s", err)
for key, vals := range keyVals {
if len(vals) == 1 {
val := vals[0]
if strings.HasPrefix(val, "@") {
val = val[1:]
input[key] = Input{File: &val}
} else {
val = expandedVal
input[key] = Input{String: &val}
}

input[key] = Input{File: &val}
} else {
input[key] = Input{String: &val}
} else if len(vals) > 1 {
var anyVals = make([]any, len(vals))
for i, v := range vals {
anyVals[i] = v
}
input[key] = Input{Array: &anyVals}
}
}
return input
Expand All @@ -55,19 +55,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
return input
}

func (inputs *Inputs) toMap() (map[string]string, error) {
keyVals := map[string]string{}
func (inputs *Inputs) toMap() (map[string]any, error) {
keyVals := map[string]any{}
for key, input := range *inputs {
if input.String != nil {
// Directly assign the string value
keyVals[key] = *input.String
} else if input.File != nil {
content, err := os.ReadFile(*input.File)
// Single file handling: read content and convert to a data URL
dataURL, err := fileToDataURL(*input.File)
if err != nil {
return keyVals, err
}
mimeType := mime.TypeByExtension(filepath.Ext(*input.File))
keyVals[key] = dataurl.New(content, mimeType).String()
keyVals[key] = dataURL
} else if input.Array != nil {
// Handle array, potentially containing file paths
dataURLs := make([]string, len(*input.Array))
for i, elem := range *input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs[i] = dataURL
} else if ok {
// Directly use the string if it's not a file path
dataURLs[i] = str
}
}
keyVals[key] = dataURLs
}
}
return keyVals, nil
}

// Helper function to read file content and convert to a data URL
func fileToDataURL(filePath string) (string, error) {
// Expand home directory if necessary
expandedVal, err := homedir.Expand(filePath)
if err != nil {
return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err)
}

content, err := os.ReadFile(expandedVal)
if err != nil {
return "", err
}
mimeType := mime.TypeByExtension(filepath.Ext(expandedVal))
dataURL := dataurl.New(content, mimeType).String()
return dataURL, nil
}
2 changes: 1 addition & 1 deletion pkg/predict/predictor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type HealthcheckResponse struct {

type Request struct {
// TODO: could this be Inputs?
Input map[string]string `json:"input"`
Input map[string]interface{} `json:"input"`
}

type Response struct {
Expand Down
1 change: 1 addition & 0 deletions python/cog/command/openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

This prints a JSON object describing the inputs of the model.
"""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Copy link
Contributor Author

@dkhokhlov dkhokhlov Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is required by one of new linters in make lint:
python -m ruff format --check python
and reformatted by:
python -m ruff format python/cog/command/openapi_schema.py

import json

from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet
Expand Down
6 changes: 3 additions & 3 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class PredictionRequest(PredictionBaseModel):
output_file_prefix: t.Optional[str]

webhook: t.Optional[pydantic.AnyHttpUrl]
webhook_events_filter: t.Optional[
t.List[WebhookEvent]
] = WebhookEvent.default_events()
webhook_events_filter: t.Optional[t.List[WebhookEvent]] = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes like this suggest that you don't have your editor set up to do autoformatting the same as script/format. Would you look into that?

Copy link
Contributor Author

@dkhokhlov dkhokhlov Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is required by one of new linters in our make lint:
python -m ruff python/cog
and reformatted by:
python -m ruff format python/cog/schema.py

WebhookEvent.default_events()
)

@classmethod
def with_types(cls, input_type: t.Type[t.Any]) -> t.Any:
Expand Down
22 changes: 14 additions & 8 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,21 @@ def _predict(
input_dict = initial_prediction["input"]

for k, v in input_dict.items():
if isinstance(v, types.URLPath):
try:
try:
# Check if v is an instance of URLPath
if isinstance(v, types.URLPath):
input_dict[k] = v.convert()
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(
isinstance(item, types.URLPath) for item in v
):
input_dict[k] = [item.convert() for item in v]
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("Failed to download url path from input", exc_info=True)
return event_handler.response

for event in worker.predict(input_dict, poll=0.1):
if should_cancel.is_set():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from cog import BasePredictor, Path
from cog import BasePredictor, File


class Predictor(BasePredictor):
def predict(self, path: Path) -> str:
with open(path) as f:
return f.read()
def predict(self, file: File) -> str:
content = file.read()
if isinstance(content, bytes):
# Decode bytes to str assuming UTF-8 encoding; adjust if needed
content = content.decode('utf-8')
return content
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
text
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.11"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from cog import BasePredictor, File


class Predictor(BasePredictor):
def predict(self, files: list[File]) -> str:
output_parts = [] # Use a list to collect file contents
for f in files:
# Assuming file content is in bytes, decode to str before appending
content = f.read()
if isinstance(content, bytes):
# Decode bytes to str assuming UTF-8 encoding; adjust if needed
content = content.decode('utf-8')
output_parts.append(content)
return "\n\n".join(output_parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import BasePredictor, Path


class Predictor(BasePredictor):
def predict(self, path: Path) -> str:
with open(path) as f:
return f.read()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.11"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from cog import BasePredictor, Path

class Predictor(BasePredictor):
def predict(self, paths: list[Path]) -> str:
output_parts = [] # Use a list to collect file contents
for path in paths:
with open(path) as f:
output_parts.append(f.read())
return "".join(output_parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.8"
predict: "predict.py:Predictor"
2 changes: 1 addition & 1 deletion test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def predict(self, text: str) -> str:


def test_build_with_model(docker_image):
project_dir = Path(__file__).parent / "fixtures/file-project"
project_dir = Path(__file__).parent / "fixtures/path-project"
subprocess.run(
["cog", "build", "-t", docker_image],
cwd=project_dir,
Expand Down
29 changes: 25 additions & 4 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_predict_takes_int_inputs_and_returns_ints_to_stdout():


def test_predict_takes_file_inputs(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-input-project"
project_dir = Path(__file__).parent / "fixtures/path-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
with open(out_dir / "input.txt", "w") as fh:
Expand All @@ -46,7 +46,7 @@ def test_predict_takes_file_inputs(tmpdir_factory):


def test_predict_writes_files_to_files(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-output-project"
project_dir = Path(__file__).parent / "fixtures/path-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand All @@ -61,7 +61,7 @@ def test_predict_writes_files_to_files(tmpdir_factory):


def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-output-project"
project_dir = Path(__file__).parent / "fixtures/path-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand All @@ -76,7 +76,7 @@ def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory):


def test_predict_writes_multiple_files_to_files(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-list-output-project"
project_dir = Path(__file__).parent / "fixtures/path-list-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand Down Expand Up @@ -229,3 +229,24 @@ def test_predict_many_inputs_with_existing_image(docker_image, tmpdir_factory):
capture_output=True,
)
assert result.stdout.decode() == "hello default 20 world jpg foo 6\n"


def test_predict_path_list_input(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/path-list-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
with open(out_dir / "1.txt", "w") as fh:
fh.write("test1")
with open(out_dir / "2.txt", "w") as fh:
fh.write("test2")
cmd = ["cog", "predict", "-i", "paths=@1.txt", "-i", "paths=@2.txt"]

result = subprocess.run(
cmd,
cwd=out_dir,
check=True,
capture_output=True,
)
stdout = result.stdout.decode()
assert "test1" in stdout
assert "test2" in stdout