Skip to content

Commit

Permalink
feat: support for list in inputs (#1561)
Browse files Browse the repository at this point in the history
- added support for list[File] and list[Path];
- only in cog, still need to add list support in web UI;
- updated docs/python.md;
  • Loading branch information
dkhokhlov committed Mar 25, 2024
1 parent cffd61b commit 1289b6f
Show file tree
Hide file tree
Showing 21 changed files with 174 additions and 42 deletions.
26 changes: 26 additions & 0 deletions docs/python.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,29 @@ class Predictor(BasePredictor):
upscaled_image.save(output_path)
return Path(output_path)
```

## `List`

The List type is also supported in inputs. It can hold any supported type.

Example for **List[Path]**:
```py
class Predictor(BasePredictor):
def predict(self, paths: list[Path]) -> str:
output_parts = [] # Use a list to collect file contents
for path in paths:
with open(path) as f:
output_parts.append(f.read())
return "".join(output_parts)
```
The corresponding cog command:
```bash
$ echo test1 > 1.txt
$ echo test2 > 2.txt
$ cog predict -i paths=@1.txt -i paths=@2.txt
Running prediction...
test1

test2
```
- Note the repeated inputs with the same name "paths" which constitute the list
5 changes: 3 additions & 2 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ func handleMultipleFileOutput(prediction *predict.Response, outputSchema *openap
}

func parseInputFlags(inputs []string) (predict.Inputs, error) {
keyVals := map[string]string{}
keyVals := map[string][]string{}
for _, input := range inputs {
var name, value string

Expand All @@ -305,7 +305,8 @@ func parseInputFlags(inputs []string) (predict.Inputs, error) {
value = value[1 : len(value)-1]
}

keyVals[name] = value
// Append new values to the slice associated with the key
keyVals[name] = append(keyVals[name], value)
}

return predict.NewInputs(keyVals), nil
Expand Down
79 changes: 57 additions & 22 deletions pkg/predict/input.go
Original file line number Diff line number Diff line change
@@ -1,41 +1,41 @@
package predict

import (
"fmt"
"mime"
"os"
"path/filepath"
"strings"

"github.com/mitchellh/go-homedir"
"github.com/vincent-petithory/dataurl"

"github.com/replicate/cog/pkg/util/console"
"github.com/replicate/cog/pkg/util/mime"
)

type Input struct {
String *string
File *string
Array *[]any
}

type Inputs map[string]Input

func NewInputs(keyVals map[string]string) Inputs {
func NewInputs(keyVals map[string][]string) Inputs {
input := Inputs{}
for key, val := range keyVals {
val := val
if strings.HasPrefix(val, "@") {
val = val[1:]
expandedVal, err := homedir.Expand(val)
if err != nil {
// FIXME: handle this better?
console.Warnf("Error expanding homedir: %s", err)
for key, vals := range keyVals {
if len(vals) == 1 {
val := vals[0]
if strings.HasPrefix(val, "@") {
val = val[1:]
input[key] = Input{File: &val}
} else {
val = expandedVal
input[key] = Input{String: &val}
}

input[key] = Input{File: &val}
} else {
input[key] = Input{String: &val}
} else if len(vals) > 1 {
var anyVals = make([]any, len(vals))
for i, v := range vals {
anyVals[i] = v
}
input[key] = Input{Array: &anyVals}
}
}
return input
Expand All @@ -55,19 +55,54 @@ func NewInputsWithBaseDir(keyVals map[string]string, baseDir string) Inputs {
return input
}

func (inputs *Inputs) toMap() (map[string]string, error) {
keyVals := map[string]string{}
func (inputs *Inputs) toMap() (map[string]any, error) {
keyVals := map[string]any{}
for key, input := range *inputs {
if input.String != nil {
// Directly assign the string value
keyVals[key] = *input.String
} else if input.File != nil {
content, err := os.ReadFile(*input.File)
// Single file handling: read content and convert to a data URL
dataURL, err := fileToDataURL(*input.File)
if err != nil {
return keyVals, err
}
mimeType := mime.TypeByExtension(filepath.Ext(*input.File))
keyVals[key] = dataurl.New(content, mimeType).String()
keyVals[key] = dataURL
} else if input.Array != nil {
// Handle array, potentially containing file paths
dataURLs := make([]string, len(*input.Array))
for i, elem := range *input.Array {
if str, ok := elem.(string); ok && strings.HasPrefix(str, "@") {
filePath := str[1:] // Remove '@' prefix
dataURL, err := fileToDataURL(filePath)
if err != nil {
return keyVals, err
}
dataURLs[i] = dataURL
} else if ok {
// Directly use the string if it's not a file path
dataURLs[i] = str
}
}
keyVals[key] = dataURLs
}
}
return keyVals, nil
}

// Helper function to read file content and convert to a data URL
func fileToDataURL(filePath string) (string, error) {
// Expand home directory if necessary
expandedVal, err := homedir.Expand(filePath)
if err != nil {
return "", fmt.Errorf("error expanding homedir for '%s': %w", filePath, err)
}

content, err := os.ReadFile(expandedVal)
if err != nil {
return "", err
}
mimeType := mime.TypeByExtension(filepath.Ext(expandedVal))
dataURL := dataurl.New(content, mimeType).String()
return dataURL, nil
}
2 changes: 1 addition & 1 deletion pkg/predict/predictor.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ type HealthcheckResponse struct {

type Request struct {
// TODO: could this be Inputs?
Input map[string]string `json:"input"`
Input map[string]interface{} `json:"input"`
}

type Response struct {
Expand Down
22 changes: 14 additions & 8 deletions python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,15 +382,21 @@ def _predict(
input_dict = initial_prediction["input"]

for k, v in input_dict.items():
if isinstance(v, types.URLPath):
try:
try:
# Check if v is an instance of URLPath
if isinstance(v, types.URLPath):
input_dict[k] = v.convert()
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("failed to download url path from input", exc_info=True)
return event_handler.response
# Check if v is a list of URLPath instances
elif isinstance(v, list) and all(
isinstance(item, types.URLPath) for item in v
):
input_dict[k] = [item.convert() for item in v]
except requests.exceptions.RequestException as e:
tb = traceback.format_exc()
event_handler.append_logs(tb)
event_handler.failed(error=str(e))
log.warn("Failed to download url path from input", exc_info=True)
return event_handler.response

for event in worker.predict(input_dict, poll=0.1):
if should_cancel.is_set():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from cog import BasePredictor, Path
from cog import BasePredictor, File


class Predictor(BasePredictor):
def predict(self, path: Path) -> str:
with open(path) as f:
return f.read()
def predict(self, file: File) -> str:
content = file.read()
if isinstance(content, bytes):
# Decode bytes to str assuming UTF-8 encoding; adjust if needed
content = content.decode('utf-8')
return content
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
text
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.11"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from cog import BasePredictor, File


class Predictor(BasePredictor):
def predict(self, files: list[File]) -> str:
output_parts = [] # Use a list to collect file contents
for f in files:
# Assuming file content is in bytes, decode to str before appending
content = f.read()
if isinstance(content, bytes):
# Decode bytes to str assuming UTF-8 encoding; adjust if needed
content = content.decode('utf-8')
output_parts.append(content)
return "\n\n".join(output_parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from cog import BasePredictor, Path


class Predictor(BasePredictor):
def predict(self, path: Path) -> str:
with open(path) as f:
return f.read()
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.11"
predict: "predict.py:Predictor"
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from cog import BasePredictor, Path

class Predictor(BasePredictor):
def predict(self, paths: list[Path]) -> str:
output_parts = [] # Use a list to collect file contents
for path in paths:
with open(path) as f:
output_parts.append(f.read())
return "".join(output_parts)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
build:
python_version: "3.8"
predict: "predict.py:Predictor"
2 changes: 1 addition & 1 deletion test-integration/test_integration/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def predict(self, text: str) -> str:


def test_build_with_model(docker_image):
project_dir = Path(__file__).parent / "fixtures/file-project"
project_dir = Path(__file__).parent / "fixtures/path-project"
subprocess.run(
["cog", "build", "-t", docker_image],
cwd=project_dir,
Expand Down
29 changes: 25 additions & 4 deletions test-integration/test_integration/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_predict_takes_int_inputs_and_returns_ints_to_stdout():


def test_predict_takes_file_inputs(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-input-project"
project_dir = Path(__file__).parent / "fixtures/path-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
with open(out_dir / "input.txt", "w") as fh:
Expand All @@ -46,7 +46,7 @@ def test_predict_takes_file_inputs(tmpdir_factory):


def test_predict_writes_files_to_files(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-output-project"
project_dir = Path(__file__).parent / "fixtures/path-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand All @@ -61,7 +61,7 @@ def test_predict_writes_files_to_files(tmpdir_factory):


def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-output-project"
project_dir = Path(__file__).parent / "fixtures/path-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand All @@ -76,7 +76,7 @@ def test_predict_writes_files_to_files_with_custom_name(tmpdir_factory):


def test_predict_writes_multiple_files_to_files(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/file-list-output-project"
project_dir = Path(__file__).parent / "fixtures/path-list-output-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
result = subprocess.run(
Expand Down Expand Up @@ -229,3 +229,24 @@ def test_predict_many_inputs_with_existing_image(docker_image, tmpdir_factory):
capture_output=True,
)
assert result.stdout.decode() == "hello default 20 world jpg foo 6\n"


def test_predict_path_list_input(tmpdir_factory):
project_dir = Path(__file__).parent / "fixtures/path-list-input-project"
out_dir = pathlib.Path(tmpdir_factory.mktemp("project"))
shutil.copytree(project_dir, out_dir, dirs_exist_ok=True)
with open(out_dir / "1.txt", "w") as fh:
fh.write("test1")
with open(out_dir / "2.txt", "w") as fh:
fh.write("test2")
cmd = ["cog", "predict", "-i", "paths=@1.txt", "-i", "paths=@2.txt"]

result = subprocess.run(
cmd,
cwd=out_dir,
check=True,
capture_output=True,
)
stdout = result.stdout.decode()
assert "test1" in stdout
assert "test2" in stdout

0 comments on commit 1289b6f

Please sign in to comment.