Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for list in inputs #1561

Merged
merged 6 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 53 additions & 14 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
package predict

import (
"encoding/json"
"fmt"
"mime"
"os"
"path/filepath"
"strings"

"github.com/mitchellh/go-homedir"
"github.com/vincent-petithory/dataurl"

"github.com/mitchellh/go-homedir"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import moved, and it should move back :)

"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/util/mime"
)

type Input struct {
String *string
File *string
Array *[]interface{}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two small points:

  1. it's nice to use any rather than interface{} these days
  2. this doesn't need to be a pointer -- slices are already pointer types

}

type Inputs map[string]Input
Expand All @@ -25,15 +28,16 @@ func NewInputs(keyVals map[string]string) Inputs {
val := val
if strings.HasPrefix(val, "@") {
val = val[1:]
expandedVal, err := homedir.Expand(val)
if err != nil {
// FIXME: handle this better?
console.Warnf("Error expanding homedir: %s", err)
input[key] = Input{File: &val}
} else if strings.HasPrefix(val, "[") && strings.HasSuffix(val, "]") {
// Handle array of strings
var arr []interface{}
if err := json.Unmarshal([]byte(val), &arr); err != nil {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is JSON unmarshalling right here? That means the strings have to be quoted, which is inconsistent with the other types.

For example, I can do

cog predict hello-world -i text=world

or

cog predict resnet -i image=@input.jpg

but with this implementation, the following will throw an error

cog predict mymodel -i names=[foo,bar]

I think there are two possibilities here. Either:

  1. do the work to allow individual strings to be either quoted or unquoted, so that they're consistent with the rest of the interface, or
  2. change approach to do arrays by just repeating inputs, e.g. cog predict ... -i names=one -i names=two

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I chose repeated inputs. Good idea. Thanks.

// Handle JSON unmarshalling error
console.Warnf("Error parsing array for key '%s': %s", key, err)
} else {
val = expandedVal
input[key] = Input{Array: &arr}
}

input[key] = Input{File: &val}
} else {
input[key] = Input{String: &val}
}
Expand All @@ -55,19 +59,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
return input
}

func (inputs *Inputs) toMap() (map[string]string, error) {
keyVals := map[string]string{}
func (inputs *Inputs) toMap() (map[string]interface{}, error) {
keyVals := map[string]interface{}{}
for key, input := range *inputs {
if input.String != nil {
// Directly assign the string value
keyVals[key] = *input.String
} else if input.File != nil {
content, err := os.ReadFile(*input.File)
// Single file handling: read content and convert to a data URL
dataURL, err := fileToDataURL(*input.File)
if err != nil {
return keyVals, err
}
mimeType := mime.TypeByExtension(filepath.Ext(*input.File))
keyVals[key] = dataurl.New(content, mimeType).String()
keyVals[key] = dataURL
} else if input.Array != nil {
// Handle array, potentially containing file paths
var dataURLs []string
for _, elem := range *input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs = append(dataURLs, dataURL)
} else if ok {
// Directly use the string if it's not a file path
dataURLs = append(dataURLs, str)
}
}
keyVals[key] = dataURLs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a big deal, but it would be more conventional when you know the length of the slice beforehand to do something like

Suggested change
var dataURLs []string
for _, elem := range *input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs = append(dataURLs, dataURL)
} else if ok {
// Directly use the string if it's not a file path
dataURLs = append(dataURLs, str)
}
}
keyVals[key] = dataURLs
dataURLs := make([]string, len(input.Array))
for i, elem := range input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs[i] = dataURL
} else if ok {
// Directly use the string if it's not a file path
dataURLs[i] = str
}
}
keyVals[key] = dataURLs

}
}
return keyVals, nil
}

// Helper function to read file content and convert to a data URL
func fileToDataURL(filePath string) (string, error) {
// Expand home directory if necessary
expandedVal, err := homedir.Expand(filePath)
if err != nil {
return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err)
}

content, err := os.ReadFile(expandedVal)
if err != nil {
return "", err
}
mimeType := mime.TypeByExtension(filepath.Ext(expandedVal))
dataURL := dataurl.New(content, mimeType).String()
return dataURL, nil
}
2 changes: 1 addition & 1 deletion pkg/predict/predictor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type HealthcheckResponse struct {

type Request struct {
// TODO: could this be Inputs?
Input map[string]string `json:"input"`
Input map[string]interface{} `json:"input"`
}

type Response struct {
Expand Down
1 change: 1 addition & 0 deletions python/cog/command/openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

This prints a JSON object describing the inputs of the model.
"""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

Copy link
Contributor Author

@dkhokhlov dkhokhlov Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is required by one of new linters in make lint:
python -m ruff format --check python
and reformatted by:
python -m ruff format python/cog/command/openapi_schema.py

import json

from ..errors import CogError, ConfigDoesNotExist, PredictorNotSet
Expand Down
6 changes: 3 additions & 3 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,9 @@ class PredictionRequest(PredictionBaseModel):
output_file_prefix: t.Optional[str]

webhook: t.Optional[pydantic.AnyHttpUrl]
webhook_events_filter: t.Optional[
t.List[WebhookEvent]
] = WebhookEvent.default_events()
webhook_events_filter: t.Optional[t.List[WebhookEvent]] = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes like this suggest that you don't have your editor set up to do autoformatting the same as script/format. Would you look into that?

Copy link
Contributor Author

@dkhokhlov dkhokhlov Mar 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this change is required by one of new linters in our make lint:
python -m ruff python/cog
and reformatted by:
python -m ruff format python/cog/schema.py

WebhookEvent.default_events()
)

@classmethod
def with_types(cls, input_type: t.Type[t.Any]) -> t.Any:
Expand Down
22 changes: 14 additions & 8 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,21 @@ def _predict(
input_dict = initial_prediction["input"]

for k, v in input_dict.items():
if isinstance(v, types.URLPath):
try:
try:
# Check if v is an instance of URLPath
if isinstance(v, types.URLPath):
input_dict[k] = v.convert()
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(
isinstance(item, types.URLPath) for item in v
):
input_dict[k] = [item.convert() for item in v]
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("Failed to download url path from input", exc_info=True)
return event_handler.response

for event in worker.predict(input_dict, poll=0.1):
if should_cancel.is_set():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from cog import BasePredictor, Path
from cog import BasePredictor, File


class Predictor(BasePredictor):
def predict(self, path: Path) -> str:
with open(path) as f:
return f.read()
def predict(self, file: File) -> str:
content = file.read()
if isinstance(content, bytes):
# Decode bytes to str assuming UTF-8 encoding; adjust if needed
content = content.decode('utf-8')
return content
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
text
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.11"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from cog import BasePredictor, File


class Predictor(BasePredictor):
def predict(self, files: list[File]) -> str:
output_parts = [] # Use a list to collect file contents
for f in files:
# Assuming file content is in bytes, decode to str before appending
content = f.read()
if isinstance(content, bytes):
# Decode bytes to str assuming UTF-8 encoding; adjust if needed
content = content.decode('utf-8')
output_parts.append(content)
return "\n\n".join(output_parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import BasePredictor, Path


class Predictor(BasePredictor):
def predict(self, path: Path) -> str:
with open(path) as f:
return f.read()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.11"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from cog import BasePredictor, Path

class Predictor(BasePredictor):
def predict(self, paths: list[Path]) -> str:
output_parts = [] # Use a list to collect file contents
for path in paths:
with open(path) as f:
output_parts.append(f.read())
return "".join(output_parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.8"
predict: "predict.py:Predictor"
2 changes: 1 addition & 1 deletion test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def predict(self, text: str) -> str:


def test_build_with_model(docker_image):
project_dir = Path(__file__).parent / "fixtures/file-project"
project_dir = Path(__file__).parent / "fixtures/path-project"
subprocess.run(
["cog", "build", "-t", docker_image],
cwd=project_dir,
Expand Down
29 changes: 25 additions & 4 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_predict_takes_int_inputs_and_returns_ints_to_stdout():


def test_predict_takes_file_inputs(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-input-project"
project_dir = Path(__file__).parent / "fixtures/path-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
with open(out_dir / "input.txt", "w") as fh:
Expand All @@ -46,7 +46,7 @@ def test_predict_takes_file_inputs(tmpdir_factory):


def test_predict_writes_files_to_files(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-output-project"
project_dir = Path(__file__).parent / "fixtures/path-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand All @@ -61,7 +61,7 @@ def test_predict_writes_files_to_files(tmpdir_factory):


def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-output-project"
project_dir = Path(__file__).parent / "fixtures/path-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand All @@ -76,7 +76,7 @@ def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory):


def test_predict_writes_multiple_files_to_files(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-list-output-project"
project_dir = Path(__file__).parent / "fixtures/path-list-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand Down Expand Up @@ -229,3 +229,24 @@ def test_predict_many_inputs_with_existing_image(docker_image, tmpdir_factory):
capture_output=True,
)
assert result.stdout.decode() == "hello default 20 world jpg foo 6\n"


def test_predict_path_list_input(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/path-list-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
with open(out_dir / "1.txt", "w") as fh:
fh.write("test1")
with open(out_dir / "2.txt", "w") as fh:
fh.write("test2")
cmd = ["cog", "predict", "-i", "paths=[\"@1.txt\",\"@2.txt\"]"]

result = subprocess.run(
cmd,
cwd=out_dir,
check=True,
capture_output=True,
)
stdout = result.stdout.decode()
assert "test1" in stdout
assert "test2" in stdout